Optimize DataJuicer Agent doc & linter (#30)

This commit is contained in:
Daoyuan Chen
2025-11-10 18:17:27 +08:00
committed by GitHub
parent 1f0c5de27f
commit dba3b86ddf
14 changed files with 891 additions and 359 deletions

View File

@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import inspect
from data_juicer.tools.op_search import OPSearcher
@@ -7,7 +8,11 @@ all_ops = searcher.search()
dj_func_info = []
for i, op in enumerate(all_ops):
class_entry = {"index": i, "class_name": op["name"], "class_desc": op["desc"]}
class_entry = {
"index": i,
"class_name": op["name"],
"class_desc": op["desc"],
}
param_desc = op["param_desc"]
param_desc_map = {}
args = ""
@@ -27,7 +32,8 @@ for i, op in enumerate(all_ops):
):
continue
if param_name in param_desc_map:
args += f" {param_name} ({param.annotation}): {param_desc_map[param_name]}\n"
args += f" {param_name} ({param.annotation}):"
args += f" {param_desc_map[param_name]}\n"
else:
args += f" {param_name} ({param.annotation})\n"
class_entry["arguments"] = args

View File

@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
import os
import os.path as osp
import json
import logging
import pickle
import hashlib
import time
from typing import Optional
@@ -18,17 +18,22 @@ _cached_vector_store: Optional[FAISS] = None
_cached_tools_info: Optional[list] = None
_cached_file_hash: Optional[str] = None
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant
responsible for filtering the top {limit} most relevant tools from a large
tool library based on user requirements. Execute the following steps:
# Requirement Analysis
Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).
Carefully read the user's [requirement description], extract core keywords,
functional objectives, usage scenarios, and technical requirements
(such as real-time performance, data types, industry domains, etc.).
# Tool Matching
Perform multi-dimensional matching based on the following tool attributes:
- Tool name and functional description
- Supported input/output formats
- Applicable industry or scenario tags
- Technical implementation principles (API, local deployment, AI model types)
- Technical implementation principles
(API, local deployment, AI model types)
- Relevance ranking
# Use weighted scoring mechanism (example weights):
@@ -59,7 +64,8 @@ RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsibl
"key_match": ["Matching keywords/features"]
}}
]
Output strictly in JSON array format, and only output the JSON array format tool list.
Output strictly in JSON array format, and only output the JSON array format
tool list.
"""
@@ -96,9 +102,15 @@ async def retrieve_ops_lm(user_query, limit=20):
else:
from create_dj_func_info import dj_func_info
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
project_root = os.path.abspath(
os.path.join(os.path.dirname(__file__), ".."),
)
with open(os.path.join(project_root, TOOLS_INFO_PATH), "w") as f:
with open(
os.path.join(project_root, TOOLS_INFO_PATH),
"w",
encoding="utf-8",
) as f:
f.write(json.dumps(dj_func_info))
tool_descriptions = [
@@ -123,15 +135,13 @@ async def retrieve_ops_lm(user_query, limit=20):
user_prompt = (
retrieval_prompt_with_limit
+ """
User requirement description:
+ f"""
User requirement description:
{user_query}
Available tools:
{tools_string}
""".format(
user_query=user_query, tools_string=tools_string
)
"""
)
msgs = [
@@ -191,13 +201,11 @@ def _load_cached_index() -> bool:
index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")
if not all(
os.path.exists(p) for p in [index_path, metadata_path]
):
if not all(os.path.exists(p) for p in [index_path, metadata_path]):
return False
# Check if cached index matches current tools info file
with open(metadata_path, "r") as f:
with open(metadata_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
cached_hash = metadata.get("tools_info_hash", "")
@@ -215,7 +223,9 @@ def _load_cached_index() -> bool:
)
_cached_vector_store = FAISS.load_local(
index_path, embeddings, allow_dangerous_deserialization=True
index_path,
embeddings,
allow_dangerous_deserialization=True,
)
_cached_file_hash = cached_hash
@@ -244,8 +254,11 @@ def _save_cached_index():
_cached_vector_store.save_local(index_path)
# Save metadata
metadata = {"tools_info_hash": _cached_file_hash, "created_at": time.time()}
with open(metadata_path, "w") as f:
metadata = {
"tools_info_hash": _cached_file_hash,
"created_at": time.time(),
}
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f)
logging.info("Successfully saved vector index to cache")
@@ -261,16 +274,23 @@ def _build_vector_index():
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
tools_info = json.loads(f.read())
tool_descriptions = [f"{t['class_name']}: {t['class_desc']}" for t in tools_info]
tool_descriptions = [
f"{t['class_name']}: {t['class_desc']}" for t in tools_info
]
from langchain_community.embeddings import DashScopeEmbeddings
embeddings = DashScopeEmbeddings(
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"), model="text-embedding-v1"
dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
model="text-embedding-v1",
)
metadatas = [{"index": i} for i in range(len(tool_descriptions))]
vector_store = FAISS.from_texts(tool_descriptions, embeddings, metadatas=metadatas)
vector_store = FAISS.from_texts(
tool_descriptions,
embeddings,
metadatas=metadatas,
)
# Cache the results
_cached_vector_store = vector_store
@@ -283,7 +303,7 @@ def _build_vector_index():
def retrieve_ops_vector(user_query, limit=20):
"""Tool retrieval using vector search with caching - returns list of tool names"""
"""Tool retrieval using vector search with caching"""
global _cached_vector_store
# Try to load from cache first
@@ -292,7 +312,10 @@ def retrieve_ops_vector(user_query, limit=20):
_build_vector_index()
# Perform similarity search
retrieved_tools = _cached_vector_store.similarity_search(user_query, k=limit)
retrieved_tools = _cached_vector_store.similarity_search(
user_query,
k=limit,
)
retrieved_indices = [doc.metadata["index"] for doc in retrieved_tools]
with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
@@ -307,7 +330,11 @@ def retrieve_ops_vector(user_query, limit=20):
return tool_names
async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") -> list:
async def retrieve_ops(
user_query: str,
limit: int = 20,
mode: str = "auto",
) -> list:
"""
Tool retrieval with configurable mode
@@ -322,59 +349,56 @@ async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") ->
Returns:
List of tool names
"""
if mode == "llm":
if mode in ("llm", "auto"):
try:
return await retrieve_ops_lm(user_query, limit=limit)
except Exception as e:
logging.error(f"LLM retrieval failed: {str(e)}")
return []
if mode != "auto":
return []
elif mode == "vector":
if mode in ("vector", "auto"):
try:
return retrieve_ops_vector(user_query, limit=limit)
except Exception as e:
logging.error(f"Vector retrieval failed: {str(e)}")
return []
elif mode == "auto":
try:
return await retrieve_ops_lm(user_query, limit=limit)
except Exception as e:
import traceback
print(traceback.format_exc())
try:
return retrieve_ops_vector(user_query, limit=limit)
except Exception as fallback_e:
logging.error(
f"Tool retrieval failed: {str(e)}, fallback retrieval also failed: {str(fallback_e)}"
)
return []
else:
raise ValueError(f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'")
raise ValueError(
f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'",
)
if __name__ == "__main__":
import asyncio
user_query = (
"Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms."
+ "Additionally, filter out small images, perform image tagging, and remove duplicate images."
query = (
"Clean special characters from text and filter samples with "
+ "excessive length. Mask sensitive information and filter "
+ "unsafe content including adult/terror-related terms. "
+ "Additionally, filter out small images, perform image "
+ "tagging, and remove duplicate images."
)
# Test different modes
print("=== Testing LLM mode ===")
tool_names_llm = asyncio.run(retrieve_ops(user_query, limit=10, mode="llm"))
tool_names_llm = asyncio.run(
retrieve_ops(query, limit=10, mode="llm"),
)
print("Retrieved tool names (LLM):")
print(tool_names_llm)
print("\n=== Testing Vector mode ===")
tool_names_vector = asyncio.run(retrieve_ops(user_query, limit=10, mode="vector"))
tool_names_vector = asyncio.run(
retrieve_ops(query, limit=10, mode="vector"),
)
print("Retrieved tool names (Vector):")
print(tool_names_vector)
print("\n=== Testing Auto mode (default) ===")
tool_names_auto = asyncio.run(retrieve_ops(user_query, limit=10, mode="auto"))
tool_names_auto = asyncio.run(
retrieve_ops(query, limit=10, mode="auto"),
)
print("Retrieved tool names (Auto):")
print(tool_names_auto)