Optimize DataJuicer Agent doc & linter (#30)
This commit is contained in:
@@ -16,17 +16,23 @@ from agentscope.tool import Toolkit
|
||||
from .dj_helpers import execute_safe_command
|
||||
from .router_helpers import agent_to_tool
|
||||
from .dj_helpers import query_dj_operators
|
||||
from .dj_dev_helpers import get_basic_files, get_operator_example, configure_data_juicer_path
|
||||
from .dj_dev_helpers import (
|
||||
get_basic_files,
|
||||
get_operator_example,
|
||||
configure_data_juicer_path,
|
||||
)
|
||||
from .mcp_helpers import get_mcp_toolkit
|
||||
|
||||
from typing import Callable


def create_toolkit(tools: List[Callable]):
    """Build a Toolkit with every given tool function registered.

    Args:
        tools: Tool functions to register. This module passes lists of
            functions (for example ``dj_tools``) and the wrappers
            produced by ``agent_to_tool``; agents are not registered
            directly, they are wrapped first (see ``agents2toolkit``).

    Returns:
        Toolkit: A new toolkit on which ``register_tool_function`` has
        been called once per entry of ``tools``, in order.
    """
    toolkit = Toolkit()
    for tool in tools:
        toolkit.register_tool_function(tool)
    return toolkit
|
||||
|
||||
|
||||
|
||||
# DJ Agent tools
|
||||
dj_tools = [
|
||||
execute_safe_command,
|
||||
@@ -50,10 +56,12 @@ mcp_tools = [
|
||||
write_text_file,
|
||||
]
|
||||
|
||||
|
||||
def agents2toolkit(agents: List[AgentBase]):
    """Expose a list of agents as tools in a freshly created toolkit.

    Each agent is converted with ``agent_to_tool`` and the resulting
    tools are handed to ``create_toolkit``, whose return value is
    returned unchanged.
    """
    return create_toolkit(list(map(agent_to_tool, agents)))
|
||||
|
||||
|
||||
dj_toolkit = create_toolkit(dj_tools)
|
||||
dj_dev_toolkit = create_toolkit(dj_dev_tools)
|
||||
|
||||
@@ -71,7 +79,6 @@ __all__ = [
|
||||
"dj_tools",
|
||||
"dj_dev_tools",
|
||||
"mcp_tools",
|
||||
"all_tools",
|
||||
"agents2toolkit",
|
||||
"dj_toolkit",
|
||||
"dj_dev_toolkit",
|
||||
@@ -85,4 +92,4 @@ __all__ = [
|
||||
"get_basic_files",
|
||||
"get_operator_example",
|
||||
"configure_data_juicer_path",
|
||||
]
|
||||
]
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
"""
|
||||
DataJuicer Development Tools
|
||||
|
||||
Tools for developing DataJuicer operators, including access to basic documentation
|
||||
and example code for different operator types.
|
||||
Tools for developing DataJuicer operators, including access to basic
|
||||
documentation and example code for different operator types.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -23,7 +23,8 @@ BASIC_LIST_RELATIVE = [
|
||||
def get_basic_files() -> ToolResponse:
|
||||
"""Get basic DataJuicer development files content.
|
||||
|
||||
Returns the content of essential files needed for DJ operator development:
|
||||
Returns the content of essential files needed for DJ operator
|
||||
development:
|
||||
- base_op.py: Base operator class
|
||||
- DeveloperGuide.md: English developer guide
|
||||
- DeveloperGuide_ZH.md: Chinese developer guide
|
||||
@@ -31,19 +32,23 @@ def get_basic_files() -> ToolResponse:
|
||||
Returns:
|
||||
ToolResponse: Combined content of all basic development files
|
||||
"""
|
||||
|
||||
global DATA_JUICER_PATH, BASIC_LIST_RELATIVE
|
||||
if DATA_JUICER_PATH is None:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
|
||||
)
|
||||
]
|
||||
text=(
|
||||
"DATA_JUICER_PATH is not configured. Please ask the "
|
||||
"user to provide the DATA_JUICER_PATH"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
combined_content = "# DataJuicer Operator Development Basic Files\n\n"
|
||||
comb_content = "# DataJuicer Operator Development Basic Files\n\n"
|
||||
|
||||
for relative_path in BASIC_LIST_RELATIVE:
|
||||
file_path = os.path.join(DATA_JUICER_PATH, relative_path)
|
||||
@@ -52,20 +57,21 @@ def get_basic_files() -> ToolResponse:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
filename = os.path.basename(file_path)
|
||||
combined_content += f"## {filename}\n\n"
|
||||
combined_content += (
|
||||
f"```{'python' if filename.endswith('.py') else 'markdown'}\n"
|
||||
)
|
||||
combined_content += content
|
||||
combined_content += "\n```\n\n"
|
||||
file_n = os.path.basename(file_path)
|
||||
comb_content += f"## {file_n}\n\n```"
|
||||
flag = "python" if file_n.endswith(".py") else "markdown"
|
||||
comb_content += f"{flag}\n"
|
||||
comb_content += content
|
||||
comb_content += "\n```\n\n"
|
||||
except Exception as e:
|
||||
combined_content += (
|
||||
comb_content += (
|
||||
f"## {os.path.basename(file_path)} (Read Failed)\n"
|
||||
)
|
||||
combined_content += f"Error: {str(e)}\n\n"
|
||||
comb_content += f"Error: {str(e)}\n\n"
|
||||
|
||||
return ToolResponse(content=[TextBlock(type="text", text=combined_content)])
|
||||
return ToolResponse(
|
||||
content=[TextBlock(type="text", text=comb_content)],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResponse(
|
||||
@@ -73,32 +79,41 @@ def get_basic_files() -> ToolResponse:
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Error occurred while getting basic files: {str(e)}",
|
||||
)
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def get_operator_example(
|
||||
requirement_description: str, limit: int = 2
|
||||
requirement_description: str,
|
||||
limit: int = 2,
|
||||
) -> ToolResponse:
|
||||
"""Get example operators based on requirement description using dynamic search.
|
||||
"""Get example operators based on requirement description using
|
||||
dynamic search.
|
||||
|
||||
Args:
|
||||
requirement_description (str): Natural language description of the operator requirement
|
||||
limit (int): Maximum number of example operators to return (default: 2)
|
||||
requirement_description (str): Natural language description of
|
||||
the operator requirement
|
||||
limit (int): Maximum number of example operators to return
|
||||
(default: 2)
|
||||
|
||||
Returns:
|
||||
ToolResponse: Example operator code and test files based on the requirement
|
||||
ToolResponse: Example operator code and test files based on
|
||||
the requirement
|
||||
"""
|
||||
|
||||
global DATA_JUICER_PATH
|
||||
if DATA_JUICER_PATH is None:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
|
||||
)
|
||||
]
|
||||
text=(
|
||||
"DATA_JUICER_PATH is not configured. Please ask the "
|
||||
"user to provide the DATA_JUICER_PATH"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -108,49 +123,56 @@ async def get_operator_example(
|
||||
# Query relevant operators using the requirement description
|
||||
# Use retrieval mode from environment variable if set
|
||||
retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
|
||||
tool_names = await retrieve_ops(requirement_description, limit=limit, mode=retrieval_mode)
|
||||
tool_names = await retrieve_ops(
|
||||
requirement_description,
|
||||
limit=limit,
|
||||
mode=retrieval_mode,
|
||||
)
|
||||
|
||||
if not tool_names:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"No relevant operators found for requirement: {requirement_description}\n"
|
||||
f"Please try with more specific keywords or check if DATA_JUICER_PATH is properly configured.",
|
||||
)
|
||||
]
|
||||
text=(
|
||||
"No relevant operators found for requirement: "
|
||||
f"{requirement_description}\n"
|
||||
"Please try with more specific keywords or "
|
||||
"check if DATA_JUICER_PATH is properly "
|
||||
"configured."
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
combined_content = (
|
||||
comb_content = (
|
||||
f"# Dynamic Operator Examples for: {requirement_description}\n\n"
|
||||
)
|
||||
combined_content += (
|
||||
comb_content += (
|
||||
f"Found {len(tool_names)} relevant operators (limit: {limit})\n\n"
|
||||
)
|
||||
|
||||
# Process each found operator
|
||||
for i, tool_name in enumerate(tool_names[:limit]):
|
||||
combined_content += f"## {i+1}. {tool_name}\n\n"
|
||||
comb_content += f"## {i+1}. {tool_name}\n\n"
|
||||
|
||||
op_type = tool_name.split("_")[-1]
|
||||
|
||||
operator_path = f"data_juicer/ops/{op_type}/{tool_name}.py"
|
||||
|
||||
# Try to find operator source file
|
||||
|
||||
full_path = os.path.join(DATA_JUICER_PATH, operator_path)
|
||||
if os.path.exists(full_path):
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
operator_code = f.read()
|
||||
|
||||
combined_content += f"### Source Code\n"
|
||||
combined_content += "```python\n"
|
||||
combined_content += operator_code
|
||||
combined_content += "\n```\n\n"
|
||||
comb_content += "### Source Code\n"
|
||||
comb_content += "```python\n"
|
||||
comb_content += operator_code
|
||||
comb_content += "\n```\n\n"
|
||||
else:
|
||||
combined_content += (
|
||||
f"**Note:** Source code file not found for `{tool_name}`.\n\n"
|
||||
)
|
||||
comb_content += "**Note:** Source code file not found for"
|
||||
comb_content += f" `{tool_name}`.\n\n"
|
||||
|
||||
test_path = f"tests/ops/{op_type}/test_{tool_name}.py"
|
||||
|
||||
@@ -159,36 +181,43 @@ async def get_operator_example(
|
||||
with open(full_test_path, "r", encoding="utf-8") as f:
|
||||
test_code = f.read()
|
||||
|
||||
combined_content += f"### Test Code\n"
|
||||
combined_content += f"**File Path:** `{test_path}`\n\n"
|
||||
combined_content += "```python\n"
|
||||
combined_content += test_code
|
||||
combined_content += "\n```\n\n"
|
||||
comb_content += "### Test Code\n"
|
||||
comb_content += f"**File Path:** `{test_path}`\n\n"
|
||||
comb_content += "```python\n"
|
||||
comb_content += test_code
|
||||
comb_content += "\n```\n\n"
|
||||
|
||||
else:
|
||||
combined_content += (
|
||||
comb_content += (
|
||||
f"**Note:** Test file not found for `{tool_name}`.\n\n"
|
||||
)
|
||||
|
||||
combined_content += "---\n\n"
|
||||
comb_content += "---\n\n"
|
||||
|
||||
return ToolResponse(content=[TextBlock(type="text", text=combined_content)])
|
||||
return ToolResponse(
|
||||
content=[TextBlock(type="text", text=comb_content)],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Error occurred while getting operator examples: {str(e)}\n"
|
||||
f"Please check the requirement description and try again.",
|
||||
)
|
||||
]
|
||||
text=(
|
||||
"Error occurred while getting operator examples: "
|
||||
f"{str(e)}\n"
|
||||
"Please check the requirement description and try "
|
||||
"again."
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
|
||||
"""Configure DataJuicer path.
|
||||
If the user provides the data_juicer_path, please use this method to configure it.
|
||||
If the user provides the data_juicer_path, please use this method to
|
||||
configure it.
|
||||
|
||||
Args:
|
||||
data_juicer_path (str): Path to DataJuicer installation
|
||||
@@ -196,8 +225,9 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
|
||||
Returns:
|
||||
ToolResponse: Configuration result
|
||||
"""
|
||||
|
||||
global DATA_JUICER_PATH
|
||||
|
||||
|
||||
data_juicer_path = os.path.expanduser(data_juicer_path)
|
||||
|
||||
try:
|
||||
@@ -206,9 +236,12 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Specified DataJuicer path does not exist: {data_juicer_path}",
|
||||
)
|
||||
]
|
||||
text=(
|
||||
"Specified DataJuicer path does not exist: "
|
||||
f"{data_juicer_path}"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Update global DATA_JUICER_PATH
|
||||
@@ -218,9 +251,12 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"DataJuicer path has been updated to: {DATA_JUICER_PATH}",
|
||||
)
|
||||
]
|
||||
text=(
    # Adjacent literals concatenate into a single str. Commas between
    # or after them would turn ``text`` into a tuple instead of the
    # message string.
    "DataJuicer path has been updated to: "
    f"{DATA_JUICER_PATH}"
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -228,7 +264,10 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Error occurred while configuring DataJuicer path: {str(e)}",
|
||||
)
|
||||
]
|
||||
text=(
|
||||
"Error occurred while configuring DataJuicer path: "
|
||||
f"{str(e)}"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from agentscope.message import TextBlock
|
||||
from agentscope.tool import ToolResponse
|
||||
from .op_manager.op_retrieval import retrieve_ops
|
||||
|
||||
# Load tool information for formatting
|
||||
TOOLS_INFO_PATH = osp.join(osp.dirname(__file__), "op_manager", "dj_funcs_all.json")
|
||||
TOOLS_INFO_PATH = osp.join(
|
||||
osp.dirname(__file__),
|
||||
"op_manager",
|
||||
"dj_funcs_all.json",
|
||||
)
|
||||
|
||||
|
||||
def _load_tools_info():
|
||||
"""Load tools information from JSON file or create it if not exists"""
|
||||
@@ -17,30 +22,35 @@ def _load_tools_info():
|
||||
return json.loads(f.read())
|
||||
else:
|
||||
from .op_manager.create_dj_func_info import dj_func_info
|
||||
|
||||
with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(dj_func_info, f)
|
||||
return dj_func_info
|
||||
|
||||
|
||||
def _format_tool_names_to_class_entries(tool_names):
    """Render the known tools among ``tool_names`` as numbered entries.

    Args:
        tool_names: Sequence of tool (class) names to format.

    Returns:
        str: One entry per recognised name, joined with newlines. Each
        entry is ``"<n>. <class_name>: <class_desc>"`` followed by a
        newline and the tool's ``arguments`` text. ``<n>`` is the name's
        1-based position in ``tool_names``, so names missing from the
        loaded tool info are skipped and leave gaps in the numbering.
        An empty string is returned when ``tool_names`` is empty.
    """
    if not tool_names:
        return ""

    # Index the loaded tool records by class name for direct lookup.
    info_by_name = {info["class_name"]: info for info in _load_tools_info()}

    entries = []
    for position, name in enumerate(tool_names, start=1):
        info = info_by_name.get(name)
        if info is None:
            continue
        header = f"{position}. {info['class_name']}: {info['class_desc']}"
        entries.append(header + "\n" + info["arguments"])

    return "\n".join(entries)
|
||||
|
||||
|
||||
async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
|
||||
"""Query DataJuicer operators by natural language description.
|
||||
|
||||
@@ -52,26 +62,33 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
|
||||
limit (int): Maximum number of operators to return (default: 20)
|
||||
|
||||
Returns:
|
||||
ToolResponse: Tool response containing matched operators with names, descriptions, and parameters
|
||||
ToolResponse: Tool response containing matched operators with names,
|
||||
descriptions, and parameters
|
||||
"""
|
||||
|
||||
try:
|
||||
# Retrieve operator names using existing functionality with limit
|
||||
# Use retrieval mode from environment variable if set
|
||||
retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
|
||||
tool_names = await retrieve_ops(query, limit=limit, mode=retrieval_mode)
|
||||
tool_names = await retrieve_ops(
|
||||
query,
|
||||
limit=limit,
|
||||
mode=retrieval_mode,
|
||||
)
|
||||
|
||||
if not tool_names:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"No matching DataJuicer operators found for query: {query}\n"
|
||||
f"Suggestions:\n"
|
||||
f"1. Use more specific keywords like 'text filter', 'image processing'\n"
|
||||
f"2. Check spelling and try alternative terms\n"
|
||||
f"3. Try English keywords for better matching",
|
||||
)
|
||||
text="No matching DataJuicer operators found for "
|
||||
f"query: {query}\n"
|
||||
"Suggestions:\n"
|
||||
"1. Use more specific keywords like 'text filter', "
|
||||
"'image processing'\n"
|
||||
"2. Check spelling and try alternative terms\n"
|
||||
"3. Try English keywords for better matching",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -79,7 +96,7 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
|
||||
retrieved_operators = _format_tool_names_to_class_entries(tool_names)
|
||||
|
||||
# Format response
|
||||
result_text = f"🔍 DataJuicer Operator Query Results\n"
|
||||
result_text = "🔍 DataJuicer Operator Query Results\n"
|
||||
result_text += f"Query: {query}\n"
|
||||
result_text += f"Limit: {limit} operators\n"
|
||||
result_text += f"{'='*50}\n\n"
|
||||
@@ -90,7 +107,7 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=result_text,
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -101,7 +118,7 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
|
||||
type="text",
|
||||
text=f"Error querying DataJuicer operators: {str(e)}\n"
|
||||
f"Please verify query parameters and retry.",
|
||||
)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -109,10 +126,11 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
|
||||
async def execute_safe_command(
|
||||
command: str,
|
||||
timeout: int = 300,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
"""Execute safe commands including DataJuicer commands and other safe system commands.
|
||||
Returns the return code, standard output and error within <returncode></returncode>,
|
||||
"""Execute safe commands including DataJuicer commands and other safe
|
||||
system commands.
|
||||
Returns the return code, standard output and error within
|
||||
<returncode></returncode>,
|
||||
<stdout></stdout> and <stderr></stderr> tags.
|
||||
|
||||
Args:
|
||||
@@ -131,39 +149,67 @@ async def execute_safe_command(
|
||||
The tool response containing the return code, standard output, and
|
||||
standard error of the executed command.
|
||||
"""
|
||||
|
||||
|
||||
# Security check: only allow safe commands
|
||||
command_stripped = command.strip()
|
||||
|
||||
|
||||
# Define allowed command prefixes for security
|
||||
allowed_commands = [
|
||||
# DataJuicer commands
|
||||
'dj-process', 'dj-analyze',
|
||||
"dj-process",
|
||||
"dj-analyze",
|
||||
# File system operations
|
||||
'mkdir', 'ls', 'pwd', 'cat', 'echo', 'cp', 'mv', 'rm',
|
||||
"mkdir",
|
||||
"ls",
|
||||
"pwd",
|
||||
"cat",
|
||||
"echo",
|
||||
"cp",
|
||||
"mv",
|
||||
"rm",
|
||||
# Text processing
|
||||
'grep', 'head', 'tail', 'wc', 'sort', 'uniq',
|
||||
"grep",
|
||||
"head",
|
||||
"tail",
|
||||
"wc",
|
||||
"sort",
|
||||
"uniq",
|
||||
# Archive operations
|
||||
'tar', 'zip', 'unzip',
|
||||
"tar",
|
||||
"zip",
|
||||
"unzip",
|
||||
# Information commands
|
||||
'which', 'whoami', 'date', 'find',
|
||||
"which",
|
||||
"whoami",
|
||||
"date",
|
||||
"find",
|
||||
# Python commands
|
||||
'python', 'python3', 'pip', 'uv'
|
||||
"python",
|
||||
"python3",
|
||||
"pip",
|
||||
"uv",
|
||||
]
|
||||
|
||||
|
||||
# Check if command starts with any allowed command
|
||||
command_allowed = False
|
||||
for allowed_cmd in allowed_commands:
|
||||
if command_stripped.startswith(allowed_cmd):
|
||||
# Additional security checks for potentially dangerous commands
|
||||
if allowed_cmd in ['rm', 'mv'] and ('/' in command_stripped or '..' in command_stripped):
|
||||
if allowed_cmd in ["rm", "mv"] and (
|
||||
"/" in command_stripped or ".." in command_stripped
|
||||
):
|
||||
# Prevent dangerous path operations
|
||||
continue
|
||||
command_allowed = True
|
||||
break
|
||||
|
||||
|
||||
if not command_allowed:
|
||||
error_msg = f"Error: Command not allowed for security reasons. Allowed commands: {', '.join(allowed_commands)}. Received command: {command}"
|
||||
error_msg = (
|
||||
"Error: Command not allowed for security reasons. "
|
||||
"Allowed commands: "
|
||||
f"{', '.join(allowed_commands)}. "
|
||||
f"Received command: {command}"
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
@@ -193,7 +239,7 @@ async def execute_safe_command(
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
stderr_suffix = (
|
||||
f"TimeoutError: The command execution exceeded "
|
||||
"TimeoutError: The command execution exceeded "
|
||||
f"the timeout of {timeout} seconds."
|
||||
)
|
||||
returncode = -1
|
||||
@@ -221,4 +267,4 @@ async def execute_safe_command(
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import Optional
|
||||
import string
|
||||
|
||||
from agentscope.tool import Toolkit
|
||||
from agentscope.mcp import HttpStatefulClient, HttpStatelessClient, StdIOStatefulClient
|
||||
from agentscope.mcp import (
|
||||
HttpStatefulClient,
|
||||
HttpStatelessClient,
|
||||
StdIOStatefulClient,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -13,6 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
root_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
|
||||
def _load_config(config_path: str) -> dict:
|
||||
"""Load MCP configuration from file"""
|
||||
try:
|
||||
@@ -23,13 +29,15 @@ def _load_config(config_path: str) -> dict:
|
||||
return config
|
||||
else:
|
||||
logger.warning(
|
||||
f"Configuration file {config_path} not found, using default settings"
|
||||
f"Configuration file {config_path} not found, "
|
||||
"using default settings",
|
||||
)
|
||||
return _create_default_config()
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading configuration: {e}")
|
||||
return _create_default_config()
|
||||
|
||||
|
||||
def _create_default_config() -> dict:
|
||||
"""Create default configuration"""
|
||||
return {
|
||||
@@ -38,10 +46,11 @@ def _create_default_config() -> dict:
|
||||
"command": "python",
|
||||
"args": ["/home/test/data_juicer/tools/DJ_mcp_recipe_flow.py"],
|
||||
"env": {"SERVER_TRANSPORT": "stdio"},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _expand_env_vars(value: str) -> str:
|
||||
"""Expand environment variables in configuration values"""
|
||||
if isinstance(value, str):
|
||||
@@ -53,6 +62,7 @@ def _expand_env_vars(value: str) -> str:
|
||||
return value
|
||||
return value
|
||||
|
||||
|
||||
async def _create_clients(config: dict, toolkit: Toolkit):
|
||||
"""Create MCP clients based on configuration"""
|
||||
server_configs = config.get("mcpServers", {})
|
||||
@@ -88,33 +98,38 @@ async def _create_clients(config: dict, toolkit: Toolkit):
|
||||
|
||||
if stateful:
|
||||
client = HttpStatefulClient(
|
||||
name=server_name, transport=transport, url=url
|
||||
name=server_name,
|
||||
transport=transport,
|
||||
url=url,
|
||||
)
|
||||
await client.connect()
|
||||
await toolkit.register_mcp_client(client)
|
||||
else:
|
||||
client = HttpStatelessClient(
|
||||
name=server_name, transport=transport, url=url
|
||||
name=server_name,
|
||||
transport=transport,
|
||||
url=url,
|
||||
)
|
||||
await toolkit.register_mcp_client(client)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid server configuration")
|
||||
|
||||
|
||||
clients.append(client)
|
||||
except Exception as e:
|
||||
if "Invalid server configuration" in str(e):
|
||||
raise e
|
||||
logger.error(f"Failed to create client {server_name}: {e}")
|
||||
|
||||
|
||||
return clients
|
||||
|
||||
|
||||
async def get_mcp_toolkit(
    config_path: Optional[str] = None,
) -> "tuple[Toolkit, list]":
    """Create a Toolkit and register the configured MCP clients on it.

    Args:
        config_path: Path to the MCP configuration file. When omitted
            (or empty), ``configs/mcp_config.json`` under ``root_path``
            is used.

    Returns:
        A ``(toolkit, clients)`` pair: the new Toolkit and the clients
        returned by ``_create_clients``. Note that this is a tuple, not
        a bare Toolkit.
    """
    if not config_path:
        config_path = os.path.join(root_path, "configs", "mcp_config.json")
    config = _load_config(config_path)
    toolkit = Toolkit()
    clients = await _create_clients(config, toolkit)
    return toolkit, clients
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import inspect
|
||||
from data_juicer.tools.op_search import OPSearcher
|
||||
|
||||
@@ -7,7 +8,11 @@ all_ops = searcher.search()
|
||||
|
||||
dj_func_info = []
|
||||
for i, op in enumerate(all_ops):
|
||||
class_entry = {"index": i, "class_name": op["name"], "class_desc": op["desc"]}
|
||||
class_entry = {
|
||||
"index": i,
|
||||
"class_name": op["name"],
|
||||
"class_desc": op["desc"],
|
||||
}
|
||||
param_desc = op["param_desc"]
|
||||
param_desc_map = {}
|
||||
args = ""
|
||||
@@ -27,7 +32,8 @@ for i, op in enumerate(all_ops):
|
||||
):
|
||||
continue
|
||||
if param_name in param_desc_map:
|
||||
args += f" {param_name} ({param.annotation}): {param_desc_map[param_name]}\n"
|
||||
args += f" {param_name} ({param.annotation}):"
|
||||
args += f" {param_desc_map[param_name]}\n"
|
||||
else:
|
||||
args += f" {param_name} ({param.annotation})\n"
|
||||
class_entry["arguments"] = args
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Optional
|
||||
@@ -18,17 +18,22 @@ _cached_vector_store: Optional[FAISS] = None
|
||||
_cached_tools_info: Optional[list] = None
|
||||
_cached_file_hash: Optional[str] = None
|
||||
|
||||
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:
|
||||
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant
|
||||
responsible for filtering the top {limit} most relevant tools from a large
|
||||
tool library based on user requirements. Execute the following steps:
|
||||
|
||||
# Requirement Analysis
|
||||
Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).
|
||||
Carefully read the user's [requirement description], extract core keywords,
|
||||
functional objectives, usage scenarios, and technical requirements
|
||||
(such as real-time performance, data types, industry domains, etc.).
|
||||
|
||||
# Tool Matching
|
||||
Perform multi-dimensional matching based on the following tool attributes:
|
||||
- Tool name and functional description
|
||||
- Supported input/output formats
|
||||
- Applicable industry or scenario tags
|
||||
- Technical implementation principles (API, local deployment, AI model types)
|
||||
- Technical implementation principles
|
||||
(API, local deployment, AI model types)
|
||||
- Relevance ranking
|
||||
|
||||
# Use weighted scoring mechanism (example weights):
|
||||
@@ -59,7 +64,8 @@ RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsibl
|
||||
"key_match": ["Matching keywords/features"]
|
||||
}}
|
||||
]
|
||||
Output strictly in JSON array format, and only output the JSON array format tool list.
|
||||
Output strictly in JSON array format, and only output the JSON array format
|
||||
tool list.
|
||||
"""
|
||||
|
||||
|
||||
@@ -96,9 +102,15 @@ async def retrieve_ops_lm(user_query, limit=20):
|
||||
else:
|
||||
from create_dj_func_info import dj_func_info
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
project_root = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), ".."),
|
||||
)
|
||||
|
||||
with open(os.path.join(project_root, TOOLS_INFO_PATH), "w") as f:
|
||||
with open(
|
||||
os.path.join(project_root, TOOLS_INFO_PATH),
|
||||
"w",
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
f.write(json.dumps(dj_func_info))
|
||||
|
||||
tool_descriptions = [
|
||||
@@ -123,15 +135,13 @@ async def retrieve_ops_lm(user_query, limit=20):
|
||||
|
||||
user_prompt = (
|
||||
retrieval_prompt_with_limit
|
||||
+ """
|
||||
User requirement description:
|
||||
+ f"""
|
||||
User requirement description:
|
||||
{user_query}
|
||||
|
||||
Available tools:
|
||||
{tools_string}
|
||||
""".format(
|
||||
user_query=user_query, tools_string=tools_string
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
msgs = [
|
||||
@@ -191,13 +201,11 @@ def _load_cached_index() -> bool:
|
||||
index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
|
||||
metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")
|
||||
|
||||
if not all(
|
||||
os.path.exists(p) for p in [index_path, metadata_path]
|
||||
):
|
||||
if not all(os.path.exists(p) for p in [index_path, metadata_path]):
|
||||
return False
|
||||
|
||||
# Check if cached index matches current tools info file
|
||||
with open(metadata_path, "r") as f:
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
cached_hash = metadata.get("tools_info_hash", "")
|
||||
@@ -215,7 +223,9 @@ def _load_cached_index() -> bool:
|
||||
)
|
||||
|
||||
_cached_vector_store = FAISS.load_local(
|
||||
index_path, embeddings, allow_dangerous_deserialization=True
|
||||
index_path,
|
||||
embeddings,
|
||||
allow_dangerous_deserialization=True,
|
||||
)
|
||||
|
||||
_cached_file_hash = cached_hash
|
||||
@@ -244,8 +254,11 @@ def _save_cached_index():
|
||||
_cached_vector_store.save_local(index_path)
|
||||
|
||||
# Save metadata
|
||||
metadata = {"tools_info_hash": _cached_file_hash, "created_at": time.time()}
|
||||
with open(metadata_path, "w") as f:
|
||||
metadata = {
|
||||
"tools_info_hash": _cached_file_hash,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
with open(metadata_path, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
logging.info("Successfully saved vector index to cache")
|
||||
@@ -261,16 +274,23 @@ def _build_vector_index():
|
||||
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
|
||||
tools_info = json.loads(f.read())
|
||||
|
||||
tool_descriptions = [f"{t['class_name']}: {t['class_desc']}" for t in tools_info]
|
||||
tool_descriptions = [
|
||||
f"{t['class_name']}: {t['class_desc']}" for t in tools_info
|
||||
]
|
||||
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
|
||||
embeddings = DashScopeEmbeddings(
|
||||
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"), model="text-embedding-v1"
|
||||
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
|
||||
model="text-embedding-v1",
|
||||
)
|
||||
|
||||
metadatas = [{"index": i} for i in range(len(tool_descriptions))]
|
||||
vector_store = FAISS.from_texts(tool_descriptions, embeddings, metadatas=metadatas)
|
||||
vector_store = FAISS.from_texts(
|
||||
tool_descriptions,
|
||||
embeddings,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
# Cache the results
|
||||
_cached_vector_store = vector_store
|
||||
@@ -283,7 +303,7 @@ def _build_vector_index():
|
||||
|
||||
|
||||
def retrieve_ops_vector(user_query, limit=20):
|
||||
"""Tool retrieval using vector search with caching - returns list of tool names"""
|
||||
"""Tool retrieval using vector search with caching"""
|
||||
global _cached_vector_store
|
||||
|
||||
# Try to load from cache first
|
||||
@@ -292,7 +312,10 @@ def retrieve_ops_vector(user_query, limit=20):
|
||||
_build_vector_index()
|
||||
|
||||
# Perform similarity search
|
||||
retrieved_tools = _cached_vector_store.similarity_search(user_query, k=limit)
|
||||
retrieved_tools = _cached_vector_store.similarity_search(
|
||||
user_query,
|
||||
k=limit,
|
||||
)
|
||||
retrieved_indices = [doc.metadata["index"] for doc in retrieved_tools]
|
||||
|
||||
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
|
||||
@@ -307,7 +330,11 @@ def retrieve_ops_vector(user_query, limit=20):
|
||||
return tool_names
|
||||
|
||||
|
||||
async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") -> list:
|
||||
async def retrieve_ops(
|
||||
user_query: str,
|
||||
limit: int = 20,
|
||||
mode: str = "auto",
|
||||
) -> list:
|
||||
"""
|
||||
Tool retrieval with configurable mode
|
||||
|
||||
@@ -322,59 +349,56 @@ async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") ->
|
||||
Returns:
|
||||
List of tool names
|
||||
"""
|
||||
if mode == "llm":
|
||||
if mode in ("llm", "auto"):
|
||||
try:
|
||||
return await retrieve_ops_lm(user_query, limit=limit)
|
||||
except Exception as e:
|
||||
logging.error(f"LLM retrieval failed: {str(e)}")
|
||||
return []
|
||||
if mode != "auto":
|
||||
return []
|
||||
|
||||
elif mode == "vector":
|
||||
if mode in ("vector", "auto"):
|
||||
try:
|
||||
return retrieve_ops_vector(user_query, limit=limit)
|
||||
except Exception as e:
|
||||
logging.error(f"Vector retrieval failed: {str(e)}")
|
||||
return []
|
||||
|
||||
elif mode == "auto":
|
||||
try:
|
||||
return await retrieve_ops_lm(user_query, limit=limit)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(traceback.format_exc())
|
||||
try:
|
||||
return retrieve_ops_vector(user_query, limit=limit)
|
||||
except Exception as fallback_e:
|
||||
logging.error(
|
||||
f"Tool retrieval failed: {str(e)}, fallback retrieval also failed: {str(fallback_e)}"
|
||||
)
|
||||
return []
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'")
|
||||
raise ValueError(
|
||||
f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
user_query = (
|
||||
"Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms."
|
||||
+ "Additionally, filter out small images, perform image tagging, and remove duplicate images."
|
||||
query = (
|
||||
"Clean special characters from text and filter samples with "
|
||||
+ "excessive length. Mask sensitive information and filter "
|
||||
+ "unsafe content including adult/terror-related terms."
|
||||
+ "Additionally, filter out small images, perform image "
|
||||
+ "tagging, and remove duplicate images."
|
||||
)
|
||||
|
||||
# Test different modes
|
||||
print("=== Testing LLM mode ===")
|
||||
tool_names_llm = asyncio.run(retrieve_ops(user_query, limit=10, mode="llm"))
|
||||
tool_names_llm = asyncio.run(
|
||||
retrieve_ops(query, limit=10, mode="llm"),
|
||||
)
|
||||
print("Retrieved tool names (LLM):")
|
||||
print(tool_names_llm)
|
||||
|
||||
print("\n=== Testing Vector mode ===")
|
||||
tool_names_vector = asyncio.run(retrieve_ops(user_query, limit=10, mode="vector"))
|
||||
tool_names_vector = asyncio.run(
|
||||
retrieve_ops(query, limit=10, mode="vector"),
|
||||
)
|
||||
print("Retrieved tool names (Vector):")
|
||||
print(tool_names_vector)
|
||||
|
||||
print("\n=== Testing Auto mode (default) ===")
|
||||
tool_names_auto = asyncio.run(retrieve_ops(user_query, limit=10, mode="auto"))
|
||||
tool_names_auto = asyncio.run(
|
||||
retrieve_ops(query, limit=10, mode="auto"),
|
||||
)
|
||||
print("Retrieved tool names (Auto):")
|
||||
print(tool_names_auto)
|
||||
|
||||
@@ -7,7 +7,9 @@ from agentscope.tool import ToolResponse
|
||||
|
||||
|
||||
def agent_to_tool(
|
||||
agent: AgentBase, tool_name: str = None, description: str = None
|
||||
agent: AgentBase,
|
||||
tool_name: str = None,
|
||||
description: str = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Convert any agent to a tool function that can be registered in toolkit.
|
||||
@@ -15,10 +17,12 @@ def agent_to_tool(
|
||||
Args:
|
||||
agent: The agent instance to convert
|
||||
tool_name: Optional custom tool name (defaults to agent.name)
|
||||
description: Optional tool description (defaults to agent's docstring or sys_prompt)
|
||||
description: Optional tool description
|
||||
(defaults to agent's docstring or sys_prompt)
|
||||
|
||||
Returns:
|
||||
A tool function that can be registered with toolkit.register_tool_function()
|
||||
A tool function that can be registered with
|
||||
toolkit.register_tool_function()
|
||||
"""
|
||||
# Get tool name and description
|
||||
if tool_name is None:
|
||||
@@ -30,8 +34,6 @@ def agent_to_tool(
|
||||
description = agent.__doc__.strip()
|
||||
elif hasattr(agent, "sys_prompt"):
|
||||
description = f"Agent: {agent.sys_prompt[:100]}..."
|
||||
elif hasattr(agent, "_sys_prompt"):
|
||||
description = f"Agent: {agent._sys_prompt[:100]}..."
|
||||
else:
|
||||
description = f"Tool function for {tool_name}"
|
||||
|
||||
@@ -56,7 +58,8 @@ def agent_to_tool(
|
||||
# Set function name and docstring
|
||||
tool_function.__name__ = f"call_{tool_name.lower().replace(' ', '_')}"
|
||||
tool_function.__doc__ = (
|
||||
f"{description}\n\nArgs:\n task (str): The task for {tool_name} to handle"
|
||||
f"{description}\n\nArgs:"
|
||||
+ "\n task (str): The task for {tool_name} to handle"
|
||||
)
|
||||
|
||||
return tool_function
|
||||
|
||||
Reference in New Issue
Block a user