Optimize DataJuicer Agent doc & linter (#30)

This commit is contained in:
Daoyuan Chen
2025-11-10 18:17:27 +08:00
committed by GitHub
parent 1f0c5de27f
commit dba3b86ddf
14 changed files with 891 additions and 359 deletions

View File

@@ -16,17 +16,23 @@ from agentscope.tool import Toolkit
from .dj_helpers import execute_safe_command
from .router_helpers import agent_to_tool
from .dj_helpers import query_dj_operators
from .dj_dev_helpers import get_basic_files, get_operator_example, configure_data_juicer_path
from .dj_dev_helpers import (
get_basic_files,
get_operator_example,
configure_data_juicer_path,
)
from .mcp_helpers import get_mcp_toolkit
def create_toolkit(tools: List[str]):
def create_toolkit(tools: List[AgentBase]):
# Create toolkit and register tools
toolkit = Toolkit()
for tool in tools:
toolkit.register_tool_function(tool)
return toolkit
# DJ Agent tools
dj_tools = [
execute_safe_command,
@@ -50,10 +56,12 @@ mcp_tools = [
write_text_file,
]
def agents2toolkit(agents: List[AgentBase]):
tools = [agent_to_tool(agent) for agent in agents]
return create_toolkit(tools)
dj_toolkit = create_toolkit(dj_tools)
dj_dev_toolkit = create_toolkit(dj_dev_tools)
@@ -71,7 +79,6 @@ __all__ = [
"dj_tools",
"dj_dev_tools",
"mcp_tools",
"all_tools",
"agents2toolkit",
"dj_toolkit",
"dj_dev_toolkit",
@@ -85,4 +92,4 @@ __all__ = [
"get_basic_files",
"get_operator_example",
"configure_data_juicer_path",
]
]

View File

@@ -2,8 +2,8 @@
"""
DataJuicer Development Tools
Tools for developing DataJuicer operators, including access to basic documentation
and example code for different operator types.
Tools for developing DataJuicer operators, including access to basic
documentation and example code for different operator types.
"""
import os
@@ -23,7 +23,8 @@ BASIC_LIST_RELATIVE = [
def get_basic_files() -> ToolResponse:
"""Get basic DataJuicer development files content.
Returns the content of essential files needed for DJ operator development:
Returns the content of essential files needed for DJ operator
development:
- base_op.py: Base operator class
- DeveloperGuide.md: English developer guide
- DeveloperGuide_ZH.md: Chinese developer guide
@@ -31,19 +32,23 @@ def get_basic_files() -> ToolResponse:
Returns:
ToolResponse: Combined content of all basic development files
"""
global DATA_JUICER_PATH, BASIC_LIST_RELATIVE
if DATA_JUICER_PATH is None:
return ToolResponse(
content=[
TextBlock(
type="text",
text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
)
]
text=(
"DATA_JUICER_PATH is not configured. Please ask the "
"user to provide the DATA_JUICER_PATH"
),
),
],
)
try:
combined_content = "# DataJuicer Operator Development Basic Files\n\n"
comb_content = "# DataJuicer Operator Development Basic Files\n\n"
for relative_path in BASIC_LIST_RELATIVE:
file_path = os.path.join(DATA_JUICER_PATH, relative_path)
@@ -52,20 +57,21 @@ def get_basic_files() -> ToolResponse:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
filename = os.path.basename(file_path)
combined_content += f"## {filename}\n\n"
combined_content += (
f"```{'python' if filename.endswith('.py') else 'markdown'}\n"
)
combined_content += content
combined_content += "\n```\n\n"
file_n = os.path.basename(file_path)
comb_content += f"## {file_n}\n\n```"
flag = "python" if file_n.endswith(".py") else "markdown"
comb_content += f"{flag}\n"
comb_content += content
comb_content += "\n```\n\n"
except Exception as e:
combined_content += (
comb_content += (
f"## {os.path.basename(file_path)} (Read Failed)\n"
)
combined_content += f"Error: {str(e)}\n\n"
comb_content += f"Error: {str(e)}\n\n"
return ToolResponse(content=[TextBlock(type="text", text=combined_content)])
return ToolResponse(
content=[TextBlock(type="text", text=comb_content)],
)
except Exception as e:
return ToolResponse(
@@ -73,32 +79,41 @@ def get_basic_files() -> ToolResponse:
TextBlock(
type="text",
text=f"Error occurred while getting basic files: {str(e)}",
)
]
),
],
)
async def get_operator_example(
requirement_description: str, limit: int = 2
requirement_description: str,
limit: int = 2,
) -> ToolResponse:
"""Get example operators based on requirement description using dynamic search.
"""Get example operators based on requirement description using
dynamic search.
Args:
requirement_description (str): Natural language description of the operator requirement
limit (int): Maximum number of example operators to return (default: 2)
requirement_description (str): Natural language description of
the operator requirement
limit (int): Maximum number of example operators to return
(default: 2)
Returns:
ToolResponse: Example operator code and test files based on the requirement
ToolResponse: Example operator code and test files based on
the requirement
"""
global DATA_JUICER_PATH
if DATA_JUICER_PATH is None:
return ToolResponse(
content=[
TextBlock(
type="text",
text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
)
]
text=(
"DATA_JUICER_PATH is not configured. Please ask the "
"user to provide the DATA_JUICER_PATH"
),
),
],
)
try:
@@ -108,49 +123,56 @@ async def get_operator_example(
# Query relevant operators using the requirement description
# Use retrieval mode from environment variable if set
retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
tool_names = await retrieve_ops(requirement_description, limit=limit, mode=retrieval_mode)
tool_names = await retrieve_ops(
requirement_description,
limit=limit,
mode=retrieval_mode,
)
if not tool_names:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"No relevant operators found for requirement: {requirement_description}\n"
f"Please try with more specific keywords or check if DATA_JUICER_PATH is properly configured.",
)
]
text=(
"No relevant operators found for requirement: "
f"{requirement_description}\n"
"Please try with more specific keywords or "
"check if DATA_JUICER_PATH is properly "
"configured."
),
),
],
)
combined_content = (
comb_content = (
f"# Dynamic Operator Examples for: {requirement_description}\n\n"
)
combined_content += (
comb_content += (
f"Found {len(tool_names)} relevant operators (limit: {limit})\n\n"
)
# Process each found operator
for i, tool_name in enumerate(tool_names[:limit]):
combined_content += f"## {i+1}. {tool_name}\n\n"
comb_content += f"## {i+1}. {tool_name}\n\n"
op_type = tool_name.split("_")[-1]
operator_path = f"data_juicer/ops/{op_type}/{tool_name}.py"
# Try to find operator source file
full_path = os.path.join(DATA_JUICER_PATH, operator_path)
if os.path.exists(full_path):
with open(full_path, "r", encoding="utf-8") as f:
operator_code = f.read()
combined_content += f"### Source Code\n"
combined_content += "```python\n"
combined_content += operator_code
combined_content += "\n```\n\n"
comb_content += "### Source Code\n"
comb_content += "```python\n"
comb_content += operator_code
comb_content += "\n```\n\n"
else:
combined_content += (
f"**Note:** Source code file not found for `{tool_name}`.\n\n"
)
comb_content += "**Note:** Source code file not found for"
comb_content += f" `{tool_name}`.\n\n"
test_path = f"tests/ops/{op_type}/test_{tool_name}.py"
@@ -159,36 +181,43 @@ async def get_operator_example(
with open(full_test_path, "r", encoding="utf-8") as f:
test_code = f.read()
combined_content += f"### Test Code\n"
combined_content += f"**File Path:** `{test_path}`\n\n"
combined_content += "```python\n"
combined_content += test_code
combined_content += "\n```\n\n"
comb_content += "### Test Code\n"
comb_content += f"**File Path:** `{test_path}`\n\n"
comb_content += "```python\n"
comb_content += test_code
comb_content += "\n```\n\n"
else:
combined_content += (
comb_content += (
f"**Note:** Test file not found for `{tool_name}`.\n\n"
)
combined_content += "---\n\n"
comb_content += "---\n\n"
return ToolResponse(content=[TextBlock(type="text", text=combined_content)])
return ToolResponse(
content=[TextBlock(type="text", text=comb_content)],
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error occurred while getting operator examples: {str(e)}\n"
f"Please check the requirement description and try again.",
)
]
text=(
"Error occurred while getting operator examples: "
f"{str(e)}\n"
"Please check the requirement description and try "
"again."
),
),
],
)
def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
"""Configure DataJuicer path.
If the user provides the data_juicer_path, please use this method to configure it.
If the user provides the data_juicer_path, please use this method to
configure it.
Args:
data_juicer_path (str): Path to DataJuicer installation
@@ -196,8 +225,9 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
Returns:
ToolResponse: Configuration result
"""
global DATA_JUICER_PATH
data_juicer_path = os.path.expanduser(data_juicer_path)
try:
@@ -206,9 +236,12 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
content=[
TextBlock(
type="text",
text=f"Specified DataJuicer path does not exist: {data_juicer_path}",
)
]
text=(
"Specified DataJuicer path does not exist: "
f"{data_juicer_path}"
),
),
],
)
# Update global DATA_JUICER_PATH
@@ -218,9 +251,12 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
content=[
TextBlock(
type="text",
text=f"DataJuicer path has been updated to: {DATA_JUICER_PATH}",
)
]
text=(
"DataJuicer path has been updated to: ",
f"{DATA_JUICER_PATH}",
),
),
],
)
except Exception as e:
@@ -228,7 +264,10 @@ def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
content=[
TextBlock(
type="text",
text=f"Error occurred while configuring DataJuicer path: {str(e)}",
)
]
text=(
"Error occurred while configuring DataJuicer path: "
f"{str(e)}"
),
),
],
)

View File

@@ -1,14 +1,19 @@
# -*- coding: utf-8 -*-
import os
import os.path as osp
import json
import asyncio
from typing import Any
from agentscope.message import TextBlock
from agentscope.tool import ToolResponse
from .op_manager.op_retrieval import retrieve_ops
# Load tool information for formatting
TOOLS_INFO_PATH = osp.join(osp.dirname(__file__), "op_manager", "dj_funcs_all.json")
TOOLS_INFO_PATH = osp.join(
osp.dirname(__file__),
"op_manager",
"dj_funcs_all.json",
)
def _load_tools_info():
"""Load tools information from JSON file or create it if not exists"""
@@ -17,30 +22,35 @@ def _load_tools_info():
return json.loads(f.read())
else:
from .op_manager.create_dj_func_info import dj_func_info
with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
json.dump(dj_func_info, f)
return dj_func_info
def _format_tool_names_to_class_entries(tool_names):
"""Convert tool names list to formatted class entries string"""
if not tool_names:
return ""
tools_info = _load_tools_info()
# Create a mapping from class_name to tool info for quick lookup
tools_map = {tool['class_name']: tool for tool in tools_info}
tools_map = {tool["class_name"]: tool for tool in tools_info}
formatted_entries = []
for i, tool_name in enumerate(tool_names):
if tool_name in tools_map:
tool_info = tools_map[tool_name]
class_entry = f"{i+1}. {tool_info['class_name']}: {tool_info['class_desc']}"
class_entry = (
f"{i+1}. {tool_info['class_name']}: {tool_info['class_desc']}"
)
class_entry += "\n" + tool_info["arguments"]
formatted_entries.append(class_entry)
return "\n".join(formatted_entries)
async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
"""Query DataJuicer operators by natural language description.
@@ -52,26 +62,33 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
limit (int): Maximum number of operators to return (default: 20)
Returns:
ToolResponse: Tool response containing matched operators with names, descriptions, and parameters
ToolResponse: Tool response containing matched operators with names,
descriptions, and parameters
"""
try:
# Retrieve operator names using existing functionality with limit
# Use retrieval mode from environment variable if set
retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
tool_names = await retrieve_ops(query, limit=limit, mode=retrieval_mode)
tool_names = await retrieve_ops(
query,
limit=limit,
mode=retrieval_mode,
)
if not tool_names:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"No matching DataJuicer operators found for query: {query}\n"
f"Suggestions:\n"
f"1. Use more specific keywords like 'text filter', 'image processing'\n"
f"2. Check spelling and try alternative terms\n"
f"3. Try English keywords for better matching",
)
text="No matching DataJuicer operators found for "
f"query: {query}\n"
"Suggestions:\n"
"1. Use more specific keywords like 'text filter', "
"'image processing'\n"
"2. Check spelling and try alternative terms\n"
"3. Try English keywords for better matching",
),
],
)
@@ -79,7 +96,7 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
retrieved_operators = _format_tool_names_to_class_entries(tool_names)
# Format response
result_text = f"🔍 DataJuicer Operator Query Results\n"
result_text = "🔍 DataJuicer Operator Query Results\n"
result_text += f"Query: {query}\n"
result_text += f"Limit: {limit} operators\n"
result_text += f"{'='*50}\n\n"
@@ -90,7 +107,7 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
TextBlock(
type="text",
text=result_text,
)
),
],
)
@@ -101,7 +118,7 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
type="text",
text=f"Error querying DataJuicer operators: {str(e)}\n"
f"Please verify query parameters and retry.",
)
),
],
)
@@ -109,10 +126,11 @@ async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
async def execute_safe_command(
command: str,
timeout: int = 300,
**kwargs: Any,
) -> ToolResponse:
"""Execute safe commands including DataJuicer commands and other safe system commands.
Returns the return code, standard output and error within <returncode></returncode>,
"""Execute safe commands including DataJuicer commands and other safe
system commands.
Returns the return code, standard output and error within
<returncode></returncode>,
<stdout></stdout> and <stderr></stderr> tags.
Args:
@@ -131,39 +149,67 @@ async def execute_safe_command(
The tool response containing the return code, standard output, and
standard error of the executed command.
"""
# Security check: only allow safe commands
command_stripped = command.strip()
# Define allowed command prefixes for security
allowed_commands = [
# DataJuicer commands
'dj-process', 'dj-analyze',
"dj-process",
"dj-analyze",
# File system operations
'mkdir', 'ls', 'pwd', 'cat', 'echo', 'cp', 'mv', 'rm',
"mkdir",
"ls",
"pwd",
"cat",
"echo",
"cp",
"mv",
"rm",
# Text processing
'grep', 'head', 'tail', 'wc', 'sort', 'uniq',
"grep",
"head",
"tail",
"wc",
"sort",
"uniq",
# Archive operations
'tar', 'zip', 'unzip',
"tar",
"zip",
"unzip",
# Information commands
'which', 'whoami', 'date', 'find',
"which",
"whoami",
"date",
"find",
# Python commands
'python', 'python3', 'pip', 'uv'
"python",
"python3",
"pip",
"uv",
]
# Check if command starts with any allowed command
command_allowed = False
for allowed_cmd in allowed_commands:
if command_stripped.startswith(allowed_cmd):
# Additional security checks for potentially dangerous commands
if allowed_cmd in ['rm', 'mv'] and ('/' in command_stripped or '..' in command_stripped):
if allowed_cmd in ["rm", "mv"] and (
"/" in command_stripped or ".." in command_stripped
):
# Prevent dangerous path operations
continue
command_allowed = True
break
if not command_allowed:
error_msg = f"Error: Command not allowed for security reasons. Allowed commands: {', '.join(allowed_commands)}. Received command: {command}"
error_msg = (
"Error: Command not allowed for security reasons. "
"Allowed commands: "
f"{', '.join(allowed_commands)}. "
f"Received command: {command}"
)
return ToolResponse(
content=[
TextBlock(
@@ -193,7 +239,7 @@ async def execute_safe_command(
except asyncio.TimeoutError:
stderr_suffix = (
f"TimeoutError: The command execution exceeded "
"TimeoutError: The command execution exceeded "
f"the timeout of {timeout} seconds."
)
returncode = -1
@@ -221,4 +267,4 @@ async def execute_safe_command(
),
),
],
)
)

View File

@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-
import json
import os
import logging
from typing import Optional, List
from typing import Optional
import string
from agentscope.tool import Toolkit
from agentscope.mcp import HttpStatefulClient, HttpStatelessClient, StdIOStatefulClient
from agentscope.mcp import (
HttpStatefulClient,
HttpStatelessClient,
StdIOStatefulClient,
)
# Configure logging
logging.basicConfig(level=logging.INFO)
@@ -13,6 +18,7 @@ logger = logging.getLogger(__name__)
root_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
def _load_config(config_path: str) -> dict:
"""Load MCP configuration from file"""
try:
@@ -23,13 +29,15 @@ def _load_config(config_path: str) -> dict:
return config
else:
logger.warning(
f"Configuration file {config_path} not found, using default settings"
f"Configuration file {config_path} not found, "
"using default settings",
)
return _create_default_config()
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return _create_default_config()
def _create_default_config() -> dict:
"""Create default configuration"""
return {
@@ -38,10 +46,11 @@ def _create_default_config() -> dict:
"command": "python",
"args": ["/home/test/data_juicer/tools/DJ_mcp_recipe_flow.py"],
"env": {"SERVER_TRANSPORT": "stdio"},
}
}
},
},
}
def _expand_env_vars(value: str) -> str:
"""Expand environment variables in configuration values"""
if isinstance(value, str):
@@ -53,6 +62,7 @@ def _expand_env_vars(value: str) -> str:
return value
return value
async def _create_clients(config: dict, toolkit: Toolkit):
"""Create MCP clients based on configuration"""
server_configs = config.get("mcpServers", {})
@@ -88,33 +98,38 @@ async def _create_clients(config: dict, toolkit: Toolkit):
if stateful:
client = HttpStatefulClient(
name=server_name, transport=transport, url=url
name=server_name,
transport=transport,
url=url,
)
await client.connect()
await toolkit.register_mcp_client(client)
else:
client = HttpStatelessClient(
name=server_name, transport=transport, url=url
name=server_name,
transport=transport,
url=url,
)
await toolkit.register_mcp_client(client)
else:
raise ValueError("Invalid server configuration")
clients.append(client)
except Exception as e:
if "Invalid server configuration" in str(e):
raise e
logger.error(f"Failed to create client {server_name}: {e}")
return clients
async def get_mcp_toolkit(config_path: Optional[str] = None) -> Toolkit:
"""Get toolkit with all MCP tools registered"""
config_path = config_path or root_path + "/configs/mcp_config.json"
config = _load_config(config_path)
toolkit = Toolkit()
clients = await _create_clients(config, toolkit)
return toolkit, clients

View File

@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import inspect
from data_juicer.tools.op_search import OPSearcher
@@ -7,7 +8,11 @@ all_ops = searcher.search()
dj_func_info = []
for i, op in enumerate(all_ops):
class_entry = {"index": i, "class_name": op["name"], "class_desc": op["desc"]}
class_entry = {
"index": i,
"class_name": op["name"],
"class_desc": op["desc"],
}
param_desc = op["param_desc"]
param_desc_map = {}
args = ""
@@ -27,7 +32,8 @@ for i, op in enumerate(all_ops):
):
continue
if param_name in param_desc_map:
args += f" {param_name} ({param.annotation}): {param_desc_map[param_name]}\n"
args += f" {param_name} ({param.annotation}):"
args += f" {param_desc_map[param_name]}\n"
else:
args += f" {param_name} ({param.annotation})\n"
class_entry["arguments"] = args

View File

@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
import os
import os.path as osp
import json
import logging
import pickle
import hashlib
import time
from typing import Optional
@@ -18,17 +18,22 @@ _cached_vector_store: Optional[FAISS] = None
_cached_tools_info: Optional[list] = None
_cached_file_hash: Optional[str] = None
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant
responsible for filtering the top {limit} most relevant tools from a large
tool library based on user requirements. Execute the following steps:
# Requirement Analysis
Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).
Carefully read the user's [requirement description], extract core keywords,
functional objectives, usage scenarios, and technical requirements
(such as real-time performance, data types, industry domains, etc.).
# Tool Matching
Perform multi-dimensional matching based on the following tool attributes:
- Tool name and functional description
- Supported input/output formats
- Applicable industry or scenario tags
- Technical implementation principles (API, local deployment, AI model types)
- Technical implementation principles
(API, local deployment, AI model types)
- Relevance ranking
# Use weighted scoring mechanism (example weights):
@@ -59,7 +64,8 @@ RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsibl
"key_match": ["Matching keywords/features"]
}}
]
Output strictly in JSON array format, and only output the JSON array format tool list.
Output strictly in JSON array format, and only output the JSON array format
tool list.
"""
@@ -96,9 +102,15 @@ async def retrieve_ops_lm(user_query, limit=20):
else:
from create_dj_func_info import dj_func_info
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), ".."),
)
with open(os.path.join(project_root, TOOLS_INFO_PATH), "w") as f:
with open(
os.path.join(project_root, TOOLS_INFO_PATH),
"w",
encoding="utf-8",
) as f:
f.write(json.dumps(dj_func_info))
tool_descriptions = [
@@ -123,15 +135,13 @@ async def retrieve_ops_lm(user_query, limit=20):
user_prompt = (
retrieval_prompt_with_limit
+ """
User requirement description:
+ f"""
User requirement description:
{user_query}
Available tools:
{tools_string}
""".format(
user_query=user_query, tools_string=tools_string
)
"""
)
msgs = [
@@ -191,13 +201,11 @@ def _load_cached_index() -> bool:
index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")
if not all(
os.path.exists(p) for p in [index_path, metadata_path]
):
if not all(os.path.exists(p) for p in [index_path, metadata_path]):
return False
# Check if cached index matches current tools info file
with open(metadata_path, "r") as f:
with open(metadata_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
cached_hash = metadata.get("tools_info_hash", "")
@@ -215,7 +223,9 @@ def _load_cached_index() -> bool:
)
_cached_vector_store = FAISS.load_local(
index_path, embeddings, allow_dangerous_deserialization=True
index_path,
embeddings,
allow_dangerous_deserialization=True,
)
_cached_file_hash = cached_hash
@@ -244,8 +254,11 @@ def _save_cached_index():
_cached_vector_store.save_local(index_path)
# Save metadata
metadata = {"tools_info_hash": _cached_file_hash, "created_at": time.time()}
with open(metadata_path, "w") as f:
metadata = {
"tools_info_hash": _cached_file_hash,
"created_at": time.time(),
}
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f)
logging.info("Successfully saved vector index to cache")
@@ -261,16 +274,23 @@ def _build_vector_index():
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
tools_info = json.loads(f.read())
tool_descriptions = [f"{t['class_name']}: {t['class_desc']}" for t in tools_info]
tool_descriptions = [
f"{t['class_name']}: {t['class_desc']}" for t in tools_info
]
from langchain_community.embeddings import DashScopeEmbeddings
embeddings = DashScopeEmbeddings(
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"), model="text-embedding-v1"
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
model="text-embedding-v1",
)
metadatas = [{"index": i} for i in range(len(tool_descriptions))]
vector_store = FAISS.from_texts(tool_descriptions, embeddings, metadatas=metadatas)
vector_store = FAISS.from_texts(
tool_descriptions,
embeddings,
metadatas=metadatas,
)
# Cache the results
_cached_vector_store = vector_store
@@ -283,7 +303,7 @@ def _build_vector_index():
def retrieve_ops_vector(user_query, limit=20):
"""Tool retrieval using vector search with caching - returns list of tool names"""
"""Tool retrieval using vector search with caching"""
global _cached_vector_store
# Try to load from cache first
@@ -292,7 +312,10 @@ def retrieve_ops_vector(user_query, limit=20):
_build_vector_index()
# Perform similarity search
retrieved_tools = _cached_vector_store.similarity_search(user_query, k=limit)
retrieved_tools = _cached_vector_store.similarity_search(
user_query,
k=limit,
)
retrieved_indices = [doc.metadata["index"] for doc in retrieved_tools]
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
@@ -307,7 +330,11 @@ def retrieve_ops_vector(user_query, limit=20):
return tool_names
async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") -> list:
async def retrieve_ops(
user_query: str,
limit: int = 20,
mode: str = "auto",
) -> list:
"""
Tool retrieval with configurable mode
@@ -322,59 +349,56 @@ async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") ->
Returns:
List of tool names
"""
if mode == "llm":
if mode in ("llm", "auto"):
try:
return await retrieve_ops_lm(user_query, limit=limit)
except Exception as e:
logging.error(f"LLM retrieval failed: {str(e)}")
return []
if mode != "auto":
return []
elif mode == "vector":
if mode in ("vector", "auto"):
try:
return retrieve_ops_vector(user_query, limit=limit)
except Exception as e:
logging.error(f"Vector retrieval failed: {str(e)}")
return []
elif mode == "auto":
try:
return await retrieve_ops_lm(user_query, limit=limit)
except Exception as e:
import traceback
print(traceback.format_exc())
try:
return retrieve_ops_vector(user_query, limit=limit)
except Exception as fallback_e:
logging.error(
f"Tool retrieval failed: {str(e)}, fallback retrieval also failed: {str(fallback_e)}"
)
return []
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'")
raise ValueError(
f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'",
)
if __name__ == "__main__":
import asyncio
user_query = (
"Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms."
+ "Additionally, filter out small images, perform image tagging, and remove duplicate images."
query = (
"Clean special characters from text and filter samples with "
+ "excessive length. Mask sensitive information and filter "
+ "unsafe content including adult/terror-related terms."
+ "Additionally, filter out small images, perform image "
+ "tagging, and remove duplicate images."
)
# Test different modes
print("=== Testing LLM mode ===")
tool_names_llm = asyncio.run(retrieve_ops(user_query, limit=10, mode="llm"))
tool_names_llm = asyncio.run(
retrieve_ops(query, limit=10, mode="llm"),
)
print("Retrieved tool names (LLM):")
print(tool_names_llm)
print("\n=== Testing Vector mode ===")
tool_names_vector = asyncio.run(retrieve_ops(user_query, limit=10, mode="vector"))
tool_names_vector = asyncio.run(
retrieve_ops(query, limit=10, mode="vector"),
)
print("Retrieved tool names (Vector):")
print(tool_names_vector)
print("\n=== Testing Auto mode (default) ===")
tool_names_auto = asyncio.run(retrieve_ops(user_query, limit=10, mode="auto"))
tool_names_auto = asyncio.run(
retrieve_ops(query, limit=10, mode="auto"),
)
print("Retrieved tool names (Auto):")
print(tool_names_auto)

View File

@@ -7,7 +7,9 @@ from agentscope.tool import ToolResponse
def agent_to_tool(
agent: AgentBase, tool_name: str = None, description: str = None
agent: AgentBase,
tool_name: str = None,
description: str = None,
) -> Callable:
"""
Convert any agent to a tool function that can be registered in toolkit.
@@ -15,10 +17,12 @@ def agent_to_tool(
Args:
agent: The agent instance to convert
tool_name: Optional custom tool name (defaults to agent.name)
description: Optional tool description (defaults to agent's docstring or sys_prompt)
description: Optional tool description
(defaults to agent's docstring or sys_prompt)
Returns:
A tool function that can be registered with toolkit.register_tool_function()
A tool function that can be registered with
toolkit.register_tool_function()
"""
# Get tool name and description
if tool_name is None:
@@ -30,8 +34,6 @@ def agent_to_tool(
description = agent.__doc__.strip()
elif hasattr(agent, "sys_prompt"):
description = f"Agent: {agent.sys_prompt[:100]}..."
elif hasattr(agent, "_sys_prompt"):
description = f"Agent: {agent._sys_prompt[:100]}..."
else:
description = f"Tool function for {tool_name}"
@@ -56,7 +58,8 @@ def agent_to_tool(
# Set function name and docstring
tool_function.__name__ = f"call_{tool_name.lower().replace(' ', '_')}"
tool_function.__doc__ = (
f"{description}\n\nArgs:\n task (str): The task for {tool_name} to handle"
f"{description}\n\nArgs:"
+ "\n task (str): The task for {tool_name} to handle"
)
return tool_function