release datajuicer agent

This commit is contained in:
道辕
2025-10-29 18:25:35 +08:00
parent e47349c843
commit 55725959ae
25 changed files with 2219 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
"""
Tools package for data-agent.
This module provides a unified entry point for all agent tools,
organized by agent type for easy access and management.
"""
from typing import Callable, List
from agentscope.agent import AgentBase
from agentscope.tool import (
view_text_file,
write_text_file,
)
from agentscope.tool import Toolkit
from .dj_tools import execute_safe_command, query_dj_operators
from .router_tools import agent_to_tool
from .dj_dev_tools import get_basic_files, get_operator_example, configure_data_juicer_path
from .mcp_tools import get_mcp_toolkit
def create_toolkit(tools: List[Callable]) -> Toolkit:
# Create a toolkit and register each tool function
toolkit = Toolkit()
for tool in tools:
toolkit.register_tool_function(tool)
return toolkit
# DJ Agent tools
dj_tools = [
execute_safe_command,
view_text_file,
write_text_file,
query_dj_operators,
]
# DJ Development Agent tools - for developing DataJuicer operators
dj_dev_tools = [
view_text_file,
write_text_file,
get_basic_files,
get_operator_example,
configure_data_juicer_path,
]
# MCP Agent tools - for advanced data processing with Recipe Flow MCP
mcp_tools = [
view_text_file,
write_text_file,
]
def agents2toolkit(agents: List[AgentBase]) -> Toolkit:
tools = [agent_to_tool(agent) for agent in agents]
return create_toolkit(tools)
dj_toolkit = create_toolkit(dj_tools)
dj_dev_toolkit = create_toolkit(dj_dev_tools)
# Registry keyed by agent type: "dj" and "dj_dev" are ready-made toolkits, "dj_mcp" and "router" are factories
all_toolkit = {
"dj": dj_toolkit,
"dj_dev": dj_dev_toolkit,
"dj_mcp": get_mcp_toolkit,
"router": agents2toolkit,
}
# Public API
__all__ = [
"dj_tools",
"dj_dev_tools",
"mcp_tools",
"all_tools",
"agents2toolkit",
"dj_toolkit",
"dj_dev_toolkit",
"get_mcp_toolkit",
# Individual tools for direct import
"execute_safe_command",
"view_text_file",
"write_text_file",
"agent_to_tool",
"query_dj_operators",
"get_basic_files",
"get_operator_example",
"configure_data_juicer_path",
]
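A minimal usage sketch, not part of the commit: the "dj" and "dj_dev" registry entries are ready-made Toolkit instances, while "dj_mcp" is an async factory that must be awaited. The `tools` import path is an assumption.

    import asyncio
    from tools import all_toolkit, dj_toolkit  # assumed import path

    async def main():
        # Prebuilt toolkits can be used directly
        assert all_toolkit["dj"] is dj_toolkit
        # "dj_mcp" is an async factory returning a (toolkit, clients) tuple
        mcp_toolkit, mcp_clients = await all_toolkit["dj_mcp"]()
        print(type(mcp_toolkit).__name__)

    asyncio.run(main())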

View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
"""
DataJuicer Development Tools
Tools for developing DataJuicer operators, including access to basic documentation
and example code for different operator types.
"""
import os
from agentscope.message import TextBlock
from agentscope.tool import ToolResponse
# DataJuicer repository path - set via the DATA_JUICER_PATH environment variable or at runtime with configure_data_juicer_path()
DATA_JUICER_PATH = os.getenv("DATA_JUICER_PATH", None)
BASIC_LIST_RELATIVE = [
"data_juicer/ops/base_op.py",
"docs/DeveloperGuide.md",
"docs/DeveloperGuide_ZH.md",
]
def get_basic_files() -> ToolResponse:
"""Get basic DataJuicer development files content.
Returns the content of essential files needed for DJ operator development:
- base_op.py: Base operator class
- DeveloperGuide.md: English developer guide
- DeveloperGuide_ZH.md: Chinese developer guide
Returns:
ToolResponse: Combined content of all basic development files
"""
if DATA_JUICER_PATH is None:
return ToolResponse(
content=[
TextBlock(
type="text",
text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
)
]
)
try:
combined_content = "# DataJuicer Operator Development Basic Files\n\n"
for relative_path in BASIC_LIST_RELATIVE:
file_path = os.path.join(DATA_JUICER_PATH, relative_path)
if os.path.exists(file_path):
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
filename = os.path.basename(file_path)
combined_content += f"## {filename}\n\n"
combined_content += (
f"```{'python' if filename.endswith('.py') else 'markdown'}\n"
)
combined_content += content
combined_content += "\n```\n\n"
except Exception as e:
combined_content += (
f"## {os.path.basename(file_path)} (Read Failed)\n"
)
combined_content += f"Error: {str(e)}\n\n"
return ToolResponse(content=[TextBlock(type="text", text=combined_content)])
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error occurred while getting basic files: {str(e)}",
)
]
)
async def get_operator_example(
requirement_description: str, limit: int = 2
) -> ToolResponse:
"""Get example operators based on requirement description using dynamic search.
Args:
requirement_description (str): Natural language description of the operator requirement
limit (int): Maximum number of example operators to return (default: 2)
Returns:
ToolResponse: Example operator code and test files based on the requirement
"""
if DATA_JUICER_PATH is None:
return ToolResponse(
content=[
TextBlock(
type="text",
text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
)
]
)
try:
# Import retrieve_ops from op_manager
from .op_manager.op_retrieval import retrieve_ops
# Query relevant operators using the requirement description
# Use retrieval mode from environment variable if set
retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
tool_names = await retrieve_ops(requirement_description, limit=limit, mode=retrieval_mode)
if not tool_names:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"No relevant operators found for requirement: {requirement_description}\n"
f"Please try with more specific keywords or check if DATA_JUICER_PATH is properly configured.",
)
]
)
combined_content = (
f"# Dynamic Operator Examples for: {requirement_description}\n\n"
)
combined_content += (
f"Found {len(tool_names)} relevant operators (limit: {limit})\n\n"
)
# Process each found operator
for i, tool_name in enumerate(tool_names[:limit]):
combined_content += f"## {i+1}. {tool_name}\n\n"
op_type = tool_name.split("_")[-1]
operator_path = f"data_juicer/ops/{op_type}/{tool_name}.py"
# Try to find operator source file
full_path = os.path.join(DATA_JUICER_PATH, operator_path)
if os.path.exists(full_path):
with open(full_path, "r", encoding="utf-8") as f:
operator_code = f.read()
combined_content += f"### Source Code\n"
combined_content += "```python\n"
combined_content += operator_code
combined_content += "\n```\n\n"
else:
combined_content += (
f"**Note:** Source code file not found for `{tool_name}`.\n\n"
)
test_path = f"tests/ops/{op_type}/test_{tool_name}.py"
full_test_path = os.path.join(DATA_JUICER_PATH, test_path)
if os.path.exists(full_test_path):
with open(full_test_path, "r", encoding="utf-8") as f:
test_code = f.read()
combined_content += f"### Test Code\n"
combined_content += f"**File Path:** `{test_path}`\n\n"
combined_content += "```python\n"
combined_content += test_code
combined_content += "\n```\n\n"
else:
combined_content += (
f"**Note:** Test file not found for `{tool_name}`.\n\n"
)
combined_content += "---\n\n"
return ToolResponse(content=[TextBlock(type="text", text=combined_content)])
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error occurred while getting operator examples: {str(e)}\n"
f"Please check the requirement description and try again.",
)
]
)
def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
"""Configure DataJuicer path.
If the user provides the data_juicer_path, please use this method to configure it.
Args:
data_juicer_path (str): Path to DataJuicer installation
Returns:
ToolResponse: Configuration result
"""
global DATA_JUICER_PATH
data_juicer_path = os.path.expanduser(data_juicer_path)
try:
if not os.path.exists(data_juicer_path):
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Specified DataJuicer path does not exist: {data_juicer_path}",
)
]
)
# Update global DATA_JUICER_PATH
DATA_JUICER_PATH = data_juicer_path
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"DataJuicer path has been updated to: {DATA_JUICER_PATH}",
)
]
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error occurred while configuring DataJuicer path: {str(e)}",
)
]
)
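A hedged driver sketch for these tools, not part of the commit; the `tools.dj_dev_tools` import path and the repository location are assumptions.

    import asyncio
    from tools.dj_dev_tools import (  # assumed import path
        configure_data_juicer_path,
        get_basic_files,
        get_operator_example,
    )

    async def main():
        # Point the tools at a local DataJuicer checkout first; the path is
        # expanded with os.path.expanduser, so "~" works
        configure_data_juicer_path("~/data_juicer")
        # TextBlock is dict-like, so the text is read by key
        print(get_basic_files().content[0]["text"][:300])
        resp = await get_operator_example("filter out very short text samples", limit=1)
        print(resp.content[0]["text"][:300])

    asyncio.run(main())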

View File

@@ -0,0 +1,224 @@
import os
import os.path as osp
import json
import asyncio
from typing import Any
from agentscope.message import TextBlock
from agentscope.tool import ToolResponse
from .op_manager.op_retrieval import retrieve_ops
# Load tool information for formatting
TOOLS_INFO_PATH = osp.join(osp.dirname(__file__), "op_manager", "dj_funcs_all.json")
def _load_tools_info():
"""Load tools information from JSON file or create it if not exists"""
if osp.exists(TOOLS_INFO_PATH):
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
return json.loads(f.read())
else:
from .op_manager.create_dj_func_info import dj_func_info
with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
json.dump(dj_func_info, f)
return dj_func_info
def _format_tool_names_to_class_entries(tool_names):
"""Convert tool names list to formatted class entries string"""
if not tool_names:
return ""
tools_info = _load_tools_info()
# Create a mapping from class_name to tool info for quick lookup
tools_map = {tool['class_name']: tool for tool in tools_info}
formatted_entries = []
for i, tool_name in enumerate(tool_names):
if tool_name in tools_map:
tool_info = tools_map[tool_name]
class_entry = f"{i+1}. {tool_info['class_name']}: {tool_info['class_desc']}"
class_entry += "\n" + tool_info["arguments"]
formatted_entries.append(class_entry)
return "\n".join(formatted_entries)
async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
"""Query DataJuicer operators by natural language description.
Retrieves relevant operators from DataJuicer library based on user query.
Supports matching by functionality, data type, and processing scenarios.
Args:
query (str): Natural language operator query
limit (int): Maximum number of operators to return (default: 20)
Returns:
ToolResponse: Tool response containing matched operators with names, descriptions, and parameters
"""
try:
# Retrieve operator names using existing functionality with limit
# Use retrieval mode from environment variable if set
retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
tool_names = await retrieve_ops(query, limit=limit, mode=retrieval_mode)
if not tool_names:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"No matching DataJuicer operators found for query: {query}\n"
f"Suggestions:\n"
f"1. Use more specific keywords like 'text filter', 'image processing'\n"
f"2. Check spelling and try alternative terms\n"
f"3. Try English keywords for better matching",
)
],
)
# Format tool names to class entries
retrieved_operators = _format_tool_names_to_class_entries(tool_names)
# Format response
result_text = "🔍 DataJuicer Operator Query Results\n"
result_text += f"Query: {query}\n"
result_text += f"Limit: {limit} operators\n"
result_text += f"{'='*50}\n\n"
result_text += retrieved_operators
return ToolResponse(
content=[
TextBlock(
type="text",
text=result_text,
)
],
)
except Exception as e:
return ToolResponse(
content=[
TextBlock(
type="text",
text=f"Error querying DataJuicer operators: {str(e)}\n"
f"Please verify query parameters and retry.",
)
],
)
async def execute_safe_command(
command: str,
timeout: int = 300,
**kwargs: Any,
) -> ToolResponse:
"""Execute safe commands including DataJuicer commands and other safe system commands.
Returns the return code, standard output and error within <returncode></returncode>,
<stdout></stdout> and <stderr></stderr> tags.
Args:
command (`str`):
The command to execute. Allowed commands include:
- DataJuicer commands: dj-process, dj-analyze
- File system commands: mkdir, ls, pwd, cat, echo, cp, mv, rm
- Text processing: grep, head, tail, wc, sort, uniq
- Archive commands: tar, zip, unzip
- Other safe commands: which, whoami, date, find
- Python tooling: python, python3, pip, uv
timeout (`int`, defaults to `300`):
The maximum time (in seconds) allowed for the command to run.
Returns:
`ToolResponse`:
The tool response containing the return code, standard output, and
standard error of the executed command.
"""
# Security check: only allow safe commands
command_stripped = command.strip()
# Define allowed command prefixes for security
allowed_commands = [
# DataJuicer commands
'dj-process', 'dj-analyze',
# File system operations
'mkdir', 'ls', 'pwd', 'cat', 'echo', 'cp', 'mv', 'rm',
# Text processing
'grep', 'head', 'tail', 'wc', 'sort', 'uniq',
# Archive operations
'tar', 'zip', 'unzip',
# Information commands
'which', 'whoami', 'date', 'find',
# Python commands
'python', 'python3', 'pip', 'uv'
]
# Match the first whitespace-delimited token against the allowlist so that,
# e.g., "rmdir" is not accepted merely because it starts with "rm". This is
# a lightweight guard, not a full sandbox: shell metacharacters can still
# chain commands, so run this tool only in trusted environments.
first_token = command_stripped.split(maxsplit=1)[0] if command_stripped else ""
command_allowed = first_token in allowed_commands
if first_token in ("rm", "mv") and ("/" in command_stripped or ".." in command_stripped):
# Prevent dangerous path operations
command_allowed = False
if not command_allowed:
error_msg = f"Error: Command not allowed for security reasons. Allowed commands: {', '.join(allowed_commands)}. Received command: {command}"
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"<returncode>-1</returncode>"
f"<stdout></stdout>"
f"<stderr>{error_msg}</stderr>"
),
),
],
)
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
bufsize=0,
)
try:
# Gather output under the timeout with communicate(); a bare proc.wait()
# can deadlock once the child fills the stdout/stderr pipe buffers
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
stdout_str = stdout.decode("utf-8")
stderr_str = stderr.decode("utf-8")
returncode = proc.returncode
except asyncio.TimeoutError:
stderr_suffix = (
f"TimeoutError: The command execution exceeded "
f"the timeout of {timeout} seconds."
)
returncode = -1
try:
proc.terminate()
stdout, stderr = await proc.communicate()
stdout_str = stdout.decode("utf-8")
stderr_str = stderr.decode("utf-8")
if stderr_str:
stderr_str += f"\n{stderr_suffix}"
else:
stderr_str = stderr_suffix
except ProcessLookupError:
stdout_str = ""
stderr_str = stderr_suffix
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"<returncode>{returncode}</returncode>"
f"<stdout>{stdout_str}</stdout>"
f"<stderr>{stderr_str}</stderr>"
),
),
],
)
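A quick check of the allowlist behavior, illustrative only; the `tools.dj_tools` import path is an assumption.

    import asyncio
    from tools.dj_tools import execute_safe_command  # assumed import path

    async def main():
        ok = await execute_safe_command("echo hello")
        print(ok.content[0]["text"])       # <returncode>0</returncode><stdout>hello\n</stdout>...
        blocked = await execute_safe_command("shutdown -h now")
        print(blocked.content[0]["text"])  # <returncode>-1</returncode>... "not allowed" message

    asyncio.run(main())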

View File

@@ -0,0 +1,120 @@
import json
import os
import logging
from typing import Optional
import string
from agentscope.tool import Toolkit
from agentscope.mcp import HttpStatefulClient, HttpStatelessClient, StdIOStatefulClient
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
root_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
def _load_config(config_path: str) -> dict:
"""Load MCP configuration from file"""
try:
if os.path.exists(config_path):
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
logger.info(f"Loaded MCP configuration from {config_path}")
return config
else:
logger.warning(
f"Configuration file {config_path} not found, using default settings"
)
return _create_default_config()
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return _create_default_config()
def _create_default_config() -> dict:
"""Create a default configuration. The hardcoded script path below is environment-specific; override it via configs/mcp_config.json."""
return {
"mcpServers": {
"dj_recipe_flow": {
"command": "python",
"args": ["/home/test/data_juicer/tools/DJ_mcp_recipe_flow.py"],
"env": {"SERVER_TRANSPORT": "stdio"},
}
}
}
def _expand_env_vars(value: str) -> str:
"""Expand environment variables in configuration values"""
if isinstance(value, str):
template = string.Template(value)
try:
return template.substitute(os.environ)
except KeyError as e:
logger.warning(f"Environment variable not found: {e}")
return value
return value
async def _create_clients(config: dict, toolkit: Toolkit):
"""Create MCP clients based on configuration"""
server_configs = config.get("mcpServers", {})
clients = []
for server_name, server_config in server_configs.items():
try:
# Handle StdIO client
if "command" in server_config:
command = server_config["command"]
args = server_config.get("args", [])
env = server_config.get("env", {})
# Expand environment variables
expanded_args = [_expand_env_vars(arg) for arg in args]
expanded_env = {k: _expand_env_vars(v) for k, v in env.items()}
client = StdIOStatefulClient(
name=server_name,
command=command,
args=expanded_args,
env=expanded_env,
)
await client.connect()
await toolkit.register_mcp_client(client)
# Handle HTTP clients
elif "url" in server_config:
url = _expand_env_vars(server_config["url"])
transport = server_config.get("transport", "sse")
stateful = server_config.get("stateful", True)
if stateful:
client = HttpStatefulClient(
name=server_name, transport=transport, url=url
)
await client.connect()
await toolkit.register_mcp_client(client)
else:
client = HttpStatelessClient(
name=server_name, transport=transport, url=url
)
await toolkit.register_mcp_client(client)
else:
raise ValueError(
f"Invalid server configuration for '{server_name}': "
"expected either a 'command' (stdio) or a 'url' (http) field"
)
clients.append(client)
except ValueError:
# Misconfiguration should surface to the caller rather than be swallowed
raise
except Exception as e:
logger.error(f"Failed to create client {server_name}: {e}")
return clients
async def get_mcp_toolkit(config_path: Optional[str] = None) -> tuple[Toolkit, list]:
"""Get a (toolkit, clients) tuple with all MCP tools registered; the clients are returned so callers can close stateful connections when done."""
config_path = config_path or os.path.join(root_path, "configs", "mcp_config.json")
config = _load_config(config_path)
toolkit = Toolkit()
clients = await _create_clients(config, toolkit)
return toolkit, clients
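For reference, a sketch of the expected configs/mcp_config.json schema, written from Python so the field names consumed by `_create_clients` are explicit. The URL and the `${DATA_JUICER_PATH}` reference are placeholders; `$VAR`-style values are expanded by `_expand_env_vars`.

    import asyncio, json

    config = {
        "mcpServers": {
            # stdio server: launched as a subprocess and connected statefully
            "dj_recipe_flow": {
                "command": "python",
                "args": ["${DATA_JUICER_PATH}/tools/DJ_mcp_recipe_flow.py"],
                "env": {"SERVER_TRANSPORT": "stdio"},
            },
            # http server: "stateful" selects HttpStatefulClient vs. stateless
            "remote_tools": {
                "url": "http://localhost:8080/sse",
                "transport": "sse",
                "stateful": True,
            },
        }
    }
    with open("mcp_config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # toolkit, clients = asyncio.run(get_mcp_toolkit("mcp_config.json"))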

View File

@@ -0,0 +1,34 @@
"""Build `dj_func_info`: one entry per DataJuicer operator with its name, description, and formatted argument documentation, extracted via OPSearcher."""
import inspect
from data_juicer.tools.op_search import OPSearcher
searcher = OPSearcher(include_formatter=False)
all_ops = searcher.search()
dj_func_info = []
for i, op in enumerate(all_ops):
class_entry = {"index": i, "class_name": op["name"], "class_desc": op["desc"]}
param_desc = op["param_desc"]
param_desc_map = {}
args = ""
for item in param_desc.split(":param"):
_item = item.split(":")
if len(_item) < 2:
continue
param_desc_map[_item[0].strip()] = ":".join(_item[1:]).strip()
if op["sig"]:
for param_name, param in op["sig"].parameters.items():
if param_name in ["self", "args", "kwargs"]:
continue
if param.kind in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
continue
if param_name in param_desc_map:
args += f" {param_name} ({param.annotation}): {param_desc_map[param_name]}\n"
else:
args += f" {param_name} ({param.annotation})\n"
class_entry["arguments"] = args
dj_func_info.append(class_entry)
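The loop above yields entries shaped like the following (values are illustrative, not taken from a real run):

    example_entry = {
        "index": 0,
        "class_name": "text_length_filter",
        "class_desc": "Filter to keep samples whose text length is within a range.",
        "arguments": (
            " min_len (<class 'int'>): The minimum text length ...\n"
            " max_len (<class 'int'>): The maximum text length ...\n"
        ),
    }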

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,380 @@
import os
import os.path as osp
import json
import logging
import pickle
import hashlib
import time
from typing import Optional
from langchain_community.vectorstores import FAISS
TOOLS_INFO_PATH = osp.join(osp.dirname(__file__), "dj_funcs_all.json")
CACHE_RETRIEVED_TOOLS_PATH = osp.join(osp.dirname(__file__), "cache_retrieve")
VECTOR_INDEX_CACHE_PATH = osp.join(osp.dirname(__file__), "vector_index_cache")
# Global variable to cache the vector store
_cached_vector_store: Optional[FAISS] = None
_cached_tools_info: Optional[list] = None
_cached_file_hash: Optional[str] = None
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:
# Requirement Analysis
Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).
# Tool Matching
Perform multi-dimensional matching based on the following tool attributes:
- Tool name and functional description
- Supported input/output formats
- Applicable industry or scenario tags
- Technical implementation principles (API, local deployment, AI model types)
# Relevance Ranking
Use a weighted scoring mechanism (example weights):
- Functional match (40%)
- Scenario compatibility (30%)
- Technical compatibility (20%)
- User rating/usage rate (10%)
# Deduplication and Optimization
Exclude the following low-quality results:
- Tools with duplicate functionality (keep only the best one)
- Tools that cannot meet basic requirements
- Tools missing critical parameter descriptions
# Constraints
- Strictly control output to a maximum of {limit} tools
- Refuse to speculate on unknown tool attributes
- Maintain accuracy of domain expertise
# Output Format
Return a JSON format TOP{limit} tool list containing:
[
{{
"rank": 1,
"tool_name": "Tool Name",
"description": "Core functionality summary",
"relevance_score": 98.7,
"key_match": ["Matching keywords/features"]
}}
]
Output strictly in JSON array format, and only output the JSON array format tool list.
"""
def fast_text_encoder(text: str) -> str:
"""Fast encoding using xxHash algorithm"""
import xxhash
hasher = xxhash.xxh64(seed=0)
hasher.update(text.encode("utf-8"))
# Return the 64-bit hash as a 16-character hexadecimal string
return hasher.hexdigest()
async def retrieve_ops_lm(user_query, limit=20):
"""Tool retrieval using language model - returns list of tool names"""
hash_id = fast_text_encoder(user_query + str(limit))
# Ensure cache directory exists
os.makedirs(CACHE_RETRIEVED_TOOLS_PATH, exist_ok=True)
cache_tools_path = osp.join(CACHE_RETRIEVED_TOOLS_PATH, f"{hash_id}.json")
if osp.exists(cache_tools_path):
with open(cache_tools_path, "r", encoding="utf-8") as f:
return json.loads(f.read())
if osp.exists(TOOLS_INFO_PATH):
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
dj_func_info = json.loads(f.read())
tool_descriptions = [
f"{t['class_name']}: {t['class_desc']}" for t in dj_func_info
]
tools_string = "\n".join(tool_descriptions)
else:
from .create_dj_func_info import dj_func_info
with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
json.dump(dj_func_info, f)
tool_descriptions = [
f"{t['class_name']}: {t['class_desc']}" for t in dj_func_info
]
tools_string = "\n".join(tool_descriptions)
from agentscope.model import DashScopeChatModel
from agentscope.message import Msg
from agentscope.formatter import DashScopeChatFormatter
model = DashScopeChatModel(
model_name="qwen-turbo",
api_key=os.environ.get("DASHSCOPE_API_KEY"),
stream=False,
)
formatter = DashScopeChatFormatter()
# Update retrieval prompt to use the specified limit
retrieval_prompt_with_limit = RETRIEVAL_PROMPT.format(limit=limit)
user_prompt = (
retrieval_prompt_with_limit
+ """
User requirement description:
{user_query}
Available tools:
{tools_string}
""".format(
user_query=user_query, tools_string=tools_string
)
)
msgs = [
Msg(name="user", role="user", content=user_prompt),
]
formatted_msgs = await formatter.format(msgs)
response = await model(formatted_msgs)
msg = Msg(name="assistant", role="assistant", content=response.content)
retrieved_tools_text = msg.get_text_content()
retrieved_tools = json.loads(retrieved_tools_text)
# Extract tool names and validate they exist
tool_names = []
for tool_info in retrieved_tools:
if not isinstance(tool_info, dict) or "tool_name" not in tool_info:
logging.warning(f"Invalid tool info format: {tool_info}")
continue
tool_name = tool_info["tool_name"]
# Verify tool exists in dj_func_info
tool_exists = any(t["class_name"] == tool_name for t in dj_func_info)
if not tool_exists:
logging.error(f"Tool not found: `{tool_name}`, skipping!")
continue
tool_names.append(tool_name)
# Cache the result
with open(cache_tools_path, "w", encoding="utf-8") as f:
json.dump(tool_names, f)
return tool_names
def _get_file_hash(file_path: str) -> str:
"""Get file content hash using SHA256"""
try:
with open(file_path, "rb") as f:
file_content = f.read()
return hashlib.sha256(file_content).hexdigest()
except (OSError, IOError):
return ""
def _load_cached_index() -> bool:
"""Load cached vector index from disk"""
global _cached_vector_store, _cached_tools_info, _cached_file_hash
try:
# Ensure cache directory exists
os.makedirs(VECTOR_INDEX_CACHE_PATH, exist_ok=True)
index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")
if not all(
os.path.exists(p) for p in [index_path, metadata_path]
):
return False
# Check if cached index matches current tools info file
with open(metadata_path, "r") as f:
metadata = json.load(f)
cached_hash = metadata.get("tools_info_hash", "")
current_hash = _get_file_hash(TOOLS_INFO_PATH)
if current_hash != cached_hash:
return False
# Load cached data
from langchain_community.embeddings import DashScopeEmbeddings
embeddings = DashScopeEmbeddings(
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
model="text-embedding-v1",
)
_cached_vector_store = FAISS.load_local(
index_path, embeddings, allow_dangerous_deserialization=True
)
_cached_file_hash = cached_hash
logging.info("Successfully loaded cached vector index")
return True
except Exception as e:
logging.warning(f"Failed to load cached index: {e}")
return False
def _save_cached_index():
"""Save vector index to disk cache"""
global _cached_vector_store, _cached_file_hash
try:
# Ensure cache directory exists
os.makedirs(VECTOR_INDEX_CACHE_PATH, exist_ok=True)
index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")
# Save vector store
if _cached_vector_store:
_cached_vector_store.save_local(index_path)
# Save metadata
metadata = {"tools_info_hash": _cached_file_hash, "created_at": time.time()}
with open(metadata_path, "w") as f:
json.dump(metadata, f)
logging.info("Successfully saved vector index to cache")
except Exception as e:
logging.error(f"Failed to save cached index: {e}")
def _build_vector_index():
"""Build and cache vector index"""
global _cached_vector_store, _cached_file_hash
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
tools_info = json.loads(f.read())
tool_descriptions = [f"{t['class_name']}: {t['class_desc']}" for t in tools_info]
from langchain_community.embeddings import DashScopeEmbeddings
embeddings = DashScopeEmbeddings(
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"), model="text-embedding-v1"
)
metadatas = [{"index": i} for i in range(len(tool_descriptions))]
vector_store = FAISS.from_texts(tool_descriptions, embeddings, metadatas=metadatas)
# Cache the results
_cached_vector_store = vector_store
_cached_file_hash = _get_file_hash(TOOLS_INFO_PATH)
# Save to disk cache
_save_cached_index()
logging.info("Successfully built and cached vector index")
def retrieve_ops_vector(user_query, limit=20):
"""Tool retrieval using vector search with caching - returns list of tool names"""
global _cached_vector_store
# Try to load from cache first
if not _load_cached_index():
logging.info("Building new vector index...")
_build_vector_index()
# Perform similarity search
retrieved_tools = _cached_vector_store.similarity_search(user_query, k=limit)
retrieved_indices = [doc.metadata["index"] for doc in retrieved_tools]
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
tools_info = json.loads(f.read())
# Extract tool names from retrieved indices
tool_names = []
for raw_idx in retrieved_indices:
tool_info = tools_info[raw_idx]
tool_names.append(tool_info["class_name"])
return tool_names
async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") -> list:
"""
Tool retrieval with configurable mode
Args:
user_query: User query string
limit: Maximum number of tools to retrieve
mode: Retrieval mode - "llm", "vector", or "auto" (default: "auto")
- "llm": Use language model only
- "vector": Use vector search only
- "auto": Try LLM first, fallback to vector search on failure
Returns:
List of tool names
"""
if mode == "llm":
try:
return await retrieve_ops_lm(user_query, limit=limit)
except Exception as e:
logging.error(f"LLM retrieval failed: {str(e)}")
return []
elif mode == "vector":
try:
return retrieve_ops_vector(user_query, limit=limit)
except Exception as e:
logging.error(f"Vector retrieval failed: {str(e)}")
return []
elif mode == "auto":
try:
return await retrieve_ops_lm(user_query, limit=limit)
except Exception as e:
logging.warning(f"LLM retrieval failed ({e}); falling back to vector search")
try:
return retrieve_ops_vector(user_query, limit=limit)
except Exception as fallback_e:
logging.error(
f"Tool retrieval failed: {str(e)}, fallback retrieval also failed: {str(fallback_e)}"
)
return []
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'")
if __name__ == "__main__":
import asyncio
user_query = (
"Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms. "
+ "Additionally, filter out small images, perform image tagging, and remove duplicate images."
)
# Test different modes
print("=== Testing LLM mode ===")
tool_names_llm = asyncio.run(retrieve_ops(user_query, limit=10, mode="llm"))
print("Retrieved tool names (LLM):")
print(tool_names_llm)
print("\n=== Testing Vector mode ===")
tool_names_vector = asyncio.run(retrieve_ops(user_query, limit=10, mode="vector"))
print("Retrieved tool names (Vector):")
print(tool_names_vector)
print("\n=== Testing Auto mode (default) ===")
tool_names_auto = asyncio.run(retrieve_ops(user_query, limit=10, mode="auto"))
print("Retrieved tool names (Auto):")
print(tool_names_auto)
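Both retrieval paths cache aggressively: per-query JSON files under cache_retrieve, and a FAISS index keyed to the SHA256 of dj_funcs_all.json. A small maintenance sketch, assuming the module is importable as shown, forces a full rebuild:

    import shutil
    from tools.op_manager import op_retrieval  # assumed import path

    # Drop cached LLM answers and the FAISS index; the next retrieve_ops()
    # call re-queries qwen-turbo and/or re-embeds all operator descriptions
    shutil.rmtree(op_retrieval.CACHE_RETRIEVED_TOOLS_PATH, ignore_errors=True)
    shutil.rmtree(op_retrieval.VECTOR_INDEX_CACHE_PATH, ignore_errors=True)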

View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""Router agent using implicit routing"""
from typing import Callable, Optional
from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import ToolResponse
def agent_to_tool(
agent: AgentBase, tool_name: Optional[str] = None, description: Optional[str] = None
) -> Callable:
"""
Convert any agent to a tool function that can be registered in toolkit.
Args:
agent: The agent instance to convert
tool_name: Optional custom tool name (defaults to agent.name)
description: Optional tool description (defaults to agent's docstring or sys_prompt)
Returns:
A tool function that can be registered with toolkit.register_tool_function()
"""
# Get tool name and description
if tool_name is None:
tool_name = getattr(agent, "name", "agent_tool")
if description is None:
# Prefer the agent's (class) docstring, then its system prompt
if agent.__doc__:
description = agent.__doc__.strip()
elif hasattr(agent, "sys_prompt"):
description = f"Agent: {agent.sys_prompt[:100]}..."
elif hasattr(agent, "_sys_prompt"):
description = f"Agent: {agent._sys_prompt[:100]}..."
else:
description = f"Tool function for {tool_name}"
async def tool_function(task: str) -> ToolResponse:
# Create message and call the agent
msg = Msg("user", task, "user")
result = await agent(msg)
# Extract content from the result
if hasattr(result, "get_content_blocks"):
content = result.get_content_blocks("text")
return ToolResponse(
content=content,
metadata={
"agent_name": getattr(agent, "name", "unknown"),
"task": task,
},
)
else:
raise ValueError(f"Not a valid Msg object: {result}")
# Set function name and docstring
tool_function.__name__ = f"call_{tool_name.lower().replace(' ', '_')}"
tool_function.__doc__ = (
f"{description}\n\nArgs:\n task (str): The task for {tool_name} to handle"
)
return tool_function
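A self-contained check of the wrapper using a stand-in agent. EchoAgent is hypothetical and only mimics the two things `agent_to_tool` relies on: being awaitable with a `Msg` and returning an object that exposes `get_content_blocks`.

    import asyncio
    from agentscope.message import Msg
    from tools.router_tools import agent_to_tool  # assumed import path

    class EchoAgent:
        """Echoes the task back to the caller."""
        name = "echo"

        async def __call__(self, msg: Msg) -> Msg:
            return Msg("echo", f"handled: {msg.get_text_content()}", "assistant")

    async def main():
        tool = agent_to_tool(EchoAgent())
        print(tool.__name__)  # call_echo
        resp = await tool("summarize dataset statistics")
        print(resp.metadata)  # {'agent_name': 'echo', 'task': 'summarize ...'}

    asyncio.run(main())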