Optimize DataJuicer Agent doc & linter (#30)
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import inspect
|
||||
from data_juicer.tools.op_search import OPSearcher
|
||||
|
||||
@@ -7,7 +8,11 @@ all_ops = searcher.search()
|
||||
|
||||
dj_func_info = []
|
||||
for i, op in enumerate(all_ops):
|
||||
class_entry = {"index": i, "class_name": op["name"], "class_desc": op["desc"]}
|
||||
class_entry = {
|
||||
"index": i,
|
||||
"class_name": op["name"],
|
||||
"class_desc": op["desc"],
|
||||
}
|
||||
param_desc = op["param_desc"]
|
||||
param_desc_map = {}
|
||||
args = ""
|
||||
@@ -27,7 +32,8 @@ for i, op in enumerate(all_ops):
|
||||
):
|
||||
continue
|
||||
if param_name in param_desc_map:
|
||||
args += f" {param_name} ({param.annotation}): {param_desc_map[param_name]}\n"
|
||||
args += f" {param_name} ({param.annotation}):"
|
||||
args += f" {param_desc_map[param_name]}\n"
|
||||
else:
|
||||
args += f" {param_name} ({param.annotation})\n"
|
||||
class_entry["arguments"] = args
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Optional
|
||||
@@ -18,17 +18,22 @@ _cached_vector_store: Optional[FAISS] = None
|
||||
_cached_tools_info: Optional[list] = None
|
||||
_cached_file_hash: Optional[str] = None
|
||||
|
||||
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:
|
||||
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant
|
||||
responsible for filtering the top {limit} most relevant tools from a large
|
||||
tool library based on user requirements. Execute the following steps:
|
||||
|
||||
# Requirement Analysis
|
||||
Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).
|
||||
Carefully read the user's [requirement description], extract core keywords,
|
||||
functional objectives, usage scenarios, and technical requirements
|
||||
(such as real-time performance, data types, industry domains, etc.).
|
||||
|
||||
# Tool Matching
|
||||
Perform multi-dimensional matching based on the following tool attributes:
|
||||
- Tool name and functional description
|
||||
- Supported input/output formats
|
||||
- Applicable industry or scenario tags
|
||||
- Technical implementation principles (API, local deployment, AI model types)
|
||||
- Technical implementation principles
|
||||
(API, local deployment, AI model types)
|
||||
- Relevance ranking
|
||||
|
||||
# Use weighted scoring mechanism (example weights):
|
||||
@@ -59,7 +64,8 @@ RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsibl
|
||||
"key_match": ["Matching keywords/features"]
|
||||
}}
|
||||
]
|
||||
Output strictly in JSON array format, and only output the JSON array format tool list.
|
||||
Output strictly in JSON array format, and only output the JSON array format
|
||||
tool list.
|
||||
"""
|
||||
|
||||
|
||||
@@ -96,9 +102,15 @@ async def retrieve_ops_lm(user_query, limit=20):
|
||||
else:
|
||||
from create_dj_func_info import dj_func_info
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
project_root = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), ".."),
|
||||
)
|
||||
|
||||
with open(os.path.join(project_root, TOOLS_INFO_PATH), "w") as f:
|
||||
with open(
|
||||
os.path.join(project_root, TOOLS_INFO_PATH),
|
||||
"w",
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
f.write(json.dumps(dj_func_info))
|
||||
|
||||
tool_descriptions = [
|
||||
@@ -123,15 +135,13 @@ async def retrieve_ops_lm(user_query, limit=20):
|
||||
|
||||
user_prompt = (
|
||||
retrieval_prompt_with_limit
|
||||
+ """
|
||||
User requirement description:
|
||||
+ f"""
|
||||
User requirement description:
|
||||
{user_query}
|
||||
|
||||
Available tools:
|
||||
{tools_string}
|
||||
""".format(
|
||||
user_query=user_query, tools_string=tools_string
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
msgs = [
|
||||
@@ -191,13 +201,11 @@ def _load_cached_index() -> bool:
|
||||
index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
|
||||
metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")
|
||||
|
||||
if not all(
|
||||
os.path.exists(p) for p in [index_path, metadata_path]
|
||||
):
|
||||
if not all(os.path.exists(p) for p in [index_path, metadata_path]):
|
||||
return False
|
||||
|
||||
# Check if cached index matches current tools info file
|
||||
with open(metadata_path, "r") as f:
|
||||
with open(metadata_path, "r", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
cached_hash = metadata.get("tools_info_hash", "")
|
||||
@@ -215,7 +223,9 @@ def _load_cached_index() -> bool:
|
||||
)
|
||||
|
||||
_cached_vector_store = FAISS.load_local(
|
||||
index_path, embeddings, allow_dangerous_deserialization=True
|
||||
index_path,
|
||||
embeddings,
|
||||
allow_dangerous_deserialization=True,
|
||||
)
|
||||
|
||||
_cached_file_hash = cached_hash
|
||||
@@ -244,8 +254,11 @@ def _save_cached_index():
|
||||
_cached_vector_store.save_local(index_path)
|
||||
|
||||
# Save metadata
|
||||
metadata = {"tools_info_hash": _cached_file_hash, "created_at": time.time()}
|
||||
with open(metadata_path, "w") as f:
|
||||
metadata = {
|
||||
"tools_info_hash": _cached_file_hash,
|
||||
"created_at": time.time(),
|
||||
}
|
||||
with open(metadata_path, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
logging.info("Successfully saved vector index to cache")
|
||||
@@ -261,16 +274,23 @@ def _build_vector_index():
|
||||
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
|
||||
tools_info = json.loads(f.read())
|
||||
|
||||
tool_descriptions = [f"{t['class_name']}: {t['class_desc']}" for t in tools_info]
|
||||
tool_descriptions = [
|
||||
f"{t['class_name']}: {t['class_desc']}" for t in tools_info
|
||||
]
|
||||
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
|
||||
embeddings = DashScopeEmbeddings(
|
||||
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"), model="text-embedding-v1"
|
||||
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
|
||||
model="text-embedding-v1",
|
||||
)
|
||||
|
||||
metadatas = [{"index": i} for i in range(len(tool_descriptions))]
|
||||
vector_store = FAISS.from_texts(tool_descriptions, embeddings, metadatas=metadatas)
|
||||
vector_store = FAISS.from_texts(
|
||||
tool_descriptions,
|
||||
embeddings,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
# Cache the results
|
||||
_cached_vector_store = vector_store
|
||||
@@ -283,7 +303,7 @@ def _build_vector_index():
|
||||
|
||||
|
||||
def retrieve_ops_vector(user_query, limit=20):
|
||||
"""Tool retrieval using vector search with caching - returns list of tool names"""
|
||||
"""Tool retrieval using vector search with caching"""
|
||||
global _cached_vector_store
|
||||
|
||||
# Try to load from cache first
|
||||
@@ -292,7 +312,10 @@ def retrieve_ops_vector(user_query, limit=20):
|
||||
_build_vector_index()
|
||||
|
||||
# Perform similarity search
|
||||
retrieved_tools = _cached_vector_store.similarity_search(user_query, k=limit)
|
||||
retrieved_tools = _cached_vector_store.similarity_search(
|
||||
user_query,
|
||||
k=limit,
|
||||
)
|
||||
retrieved_indices = [doc.metadata["index"] for doc in retrieved_tools]
|
||||
|
||||
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
|
||||
@@ -307,7 +330,11 @@ def retrieve_ops_vector(user_query, limit=20):
|
||||
return tool_names
|
||||
|
||||
|
||||
async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") -> list:
|
||||
async def retrieve_ops(
|
||||
user_query: str,
|
||||
limit: int = 20,
|
||||
mode: str = "auto",
|
||||
) -> list:
|
||||
"""
|
||||
Tool retrieval with configurable mode
|
||||
|
||||
@@ -322,59 +349,56 @@ async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") ->
|
||||
Returns:
|
||||
List of tool names
|
||||
"""
|
||||
if mode == "llm":
|
||||
if mode in ("llm", "auto"):
|
||||
try:
|
||||
return await retrieve_ops_lm(user_query, limit=limit)
|
||||
except Exception as e:
|
||||
logging.error(f"LLM retrieval failed: {str(e)}")
|
||||
return []
|
||||
if mode != "auto":
|
||||
return []
|
||||
|
||||
elif mode == "vector":
|
||||
if mode in ("vector", "auto"):
|
||||
try:
|
||||
return retrieve_ops_vector(user_query, limit=limit)
|
||||
except Exception as e:
|
||||
logging.error(f"Vector retrieval failed: {str(e)}")
|
||||
return []
|
||||
|
||||
elif mode == "auto":
|
||||
try:
|
||||
return await retrieve_ops_lm(user_query, limit=limit)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
print(traceback.format_exc())
|
||||
try:
|
||||
return retrieve_ops_vector(user_query, limit=limit)
|
||||
except Exception as fallback_e:
|
||||
logging.error(
|
||||
f"Tool retrieval failed: {str(e)}, fallback retrieval also failed: {str(fallback_e)}"
|
||||
)
|
||||
return []
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'")
|
||||
raise ValueError(
|
||||
f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
user_query = (
|
||||
"Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms."
|
||||
+ "Additionally, filter out small images, perform image tagging, and remove duplicate images."
|
||||
query = (
|
||||
"Clean special characters from text and filter samples with "
|
||||
+ "excessive length. Mask sensitive information and filter "
|
||||
+ "unsafe content including adult/terror-related terms."
|
||||
+ "Additionally, filter out small images, perform image "
|
||||
+ "tagging, and remove duplicate images."
|
||||
)
|
||||
|
||||
# Test different modes
|
||||
print("=== Testing LLM mode ===")
|
||||
tool_names_llm = asyncio.run(retrieve_ops(user_query, limit=10, mode="llm"))
|
||||
tool_names_llm = asyncio.run(
|
||||
retrieve_ops(query, limit=10, mode="llm"),
|
||||
)
|
||||
print("Retrieved tool names (LLM):")
|
||||
print(tool_names_llm)
|
||||
|
||||
print("\n=== Testing Vector mode ===")
|
||||
tool_names_vector = asyncio.run(retrieve_ops(user_query, limit=10, mode="vector"))
|
||||
tool_names_vector = asyncio.run(
|
||||
retrieve_ops(query, limit=10, mode="vector"),
|
||||
)
|
||||
print("Retrieved tool names (Vector):")
|
||||
print(tool_names_vector)
|
||||
|
||||
print("\n=== Testing Auto mode (default) ===")
|
||||
tool_names_auto = asyncio.run(retrieve_ops(user_query, limit=10, mode="auto"))
|
||||
tool_names_auto = asyncio.run(
|
||||
retrieve_ops(query, limit=10, mode="auto"),
|
||||
)
|
||||
print("Retrieved tool names (Auto):")
|
||||
print(tool_names_auto)
|
||||
|
||||
Reference in New Issue
Block a user