init
@@ -0,0 +1,78 @@ configuration.py (file name inferred from the imports in the other files)
# -*- coding: utf-8 -*-
import os
from typing import Any, Optional

from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, Field


class Configuration(BaseModel):
    """The configuration for the agent."""

    query_generator_model: str = Field(
        default="qwen-max-latest",
        metadata={
            "description": "The name of the language model to use for "
            "the agent's query generation.",
        },
    )
    query_generator_param: dict = Field(
        default={"temperature": 0.3, "stream": False},
    )

    reflection_model: str = Field(
        default="qwen-plus-latest",
        metadata={
            "description": "The name of the language model to use for"
            " the agent's reflection.",
        },
    )
    reflection_param: dict = Field(
        default={"temperature": 0.3, "stream": False},
    )

    answer_model: str = Field(
        default="qwen-plus-latest",
        metadata={
            "description": "The name of the language model to use "
            "for the agent's answer.",
        },
    )
    answer_param: dict = Field(default={"temperature": 0.3, "stream": False})

    num_of_init_q: int = Field(
        default=3,
        metadata={
            "description": "The number of initial search queries to generate.",
        },
    )

    max_research_loops: int = Field(
        default=2,
        metadata={
            "description": "The maximum number of research loops to perform.",
        },
    )

    @classmethod
    def from_runnable_config(
        cls,
        config: Optional[RunnableConfig] = None,
    ) -> "Configuration":
        """Create a Configuration instance from a RunnableConfig."""
        configurable = (
            config["configurable"]
            if config and "configurable" in config
            else {}
        )

        # Get raw values from environment or config
        raw_values: dict[str, Any] = {
            name: os.environ.get(name.upper(), configurable.get(name))
            for name in cls.model_fields.keys()
        }

        # Filter out None values
        values = {k: v for k, v in raw_values.items() if v is not None}

        return cls(**values)
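A minimal usage sketch (illustrative, not from the commit; the hypothetical env var and values assume the corresponding variables are otherwise unset): environment variables take precedence over the per-run `configurable` dict, which takes precedence over the field defaults.

```python
import os

from configuration import Configuration  # the class defined above

os.environ["QUERY_GENERATOR_MODEL"] = "qwen-max-latest"  # hypothetical
cfg = Configuration.from_runnable_config(
    {"configurable": {"max_research_loops": 5}},
)
assert cfg.query_generator_model == "qwen-max-latest"  # from the environment
assert cfg.max_research_loops == 5                     # from the config dict
assert cfg.num_of_init_q == 3                          # field default
```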
@@ -0,0 +1,123 @@ custom_search_tool.py (file name inferred from the imports in the other files)
# -*- coding: utf-8 -*-
import os
import random
import string
import time
import uuid
from base64 import b64encode
from hashlib import sha256
from hmac import new as hmac_new
from typing import Any, Dict, List

import requests
from utils import format_time


class CustomSearchTool:
    def __init__(self, search_engine: str = "quark"):
        assert search_engine in ["quark"]
        self.search_engine = search_engine

        if self.search_engine == "quark":
            self.search_func = self._quark_search
        else:
            raise NotImplementedError

    def search(
        self,
        query: str,
    ) -> List[Dict[str, Any]]:
        """
        Execute the search and return the results.
        :param query: the search query string
        :return: a list of result dicts
        """
        return self.search_func(query)

    def search_quark_to_b_signature(self, user_name, timestamp, salt: str, sk):
        """
        Compute the request signature.
        :param user_name: username
        :param timestamp: timestamp in milliseconds
        :param salt: random salt
        :param sk: secret key
        :return: base64-encoded HMAC-SHA256 signature
        """
        data = f"{user_name}{timestamp}{salt}{sk}"
        hashed = hmac_new(sk.encode("utf-8"), data.encode("utf-8"), sha256)
        return b64encode(hashed.digest()).decode("utf-8")

    def search_quark_to_b_gen_token(self, user_name: str, sk: str):
        """
        Fetch an access token.
        :param user_name: username
        :param sk: secret key
        :return: the token string
        """
        timestamp = str(int(time.time() * 1000))
        salt = "".join(random.choice(string.ascii_lowercase) for _ in range(6))
        sign = self.search_quark_to_b_signature(user_name, timestamp, salt, sk)
        post_body = {
            "userName": user_name,
            "timestamp": timestamp,
            "salt": salt,
            "sign": sign,
        }
        url = "https://zx-dsc.sm.cn/api/auth/token"
        headers = {"content-type": "application/json"}
        response = requests.post(url, json=post_body, headers=headers)
        data = response.json()
        token = data["result"]["token"]
        return token

    def _quark_search(self, query: str):
        ak = os.getenv("QUARK_AK", "")
        sk = os.getenv("QUARK_SK", "")
        token = self.search_quark_to_b_gen_token(ak, sk)
        url = "https://zx-dsc.sm.cn/api/resource/s_agg/ex/query"
        querystring = {
            "page": "1",
            "q": query,
        }
        request_id = str(uuid.uuid4())
        headers = {
            "Authorization": f"Bearer {token}",
            "request-id": request_id,
        }
        try:
            response = requests.get(url, headers=headers, params=querystring)
            if response.status_code == 200:
                data = response.json()
                if (
                    data.get("items", {}).get("@attributes", {}).get("status")
                    == "OK"
                    and data.get("items")
                    and data.get("items", {}).get("item")
                ):
                    items = data.get("items").get("item")
                    formatted_items = []
                    for item in items:
                        formatted_items.append(
                            {
                                "title": item["title"],
                                "url": item["url"],
                                "snippet": item["desc"],
                                "content": item["MainBody"],
                                "publish_date": format_time(item.get("time")),
                                "site_name": item.get("site_name", ""),
                            },
                        )
                    return formatted_items
                else:
                    return []
            else:
                return []
        except Exception as e:
            print(f"Quark search failed: {e}")
            return []
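A usage sketch (illustrative, not from the commit; assumes `QUARK_AK` and `QUARK_SK` are exported and the endpoint is reachable). Each hit carries the keys built in `_quark_search` above.

```python
from custom_search_tool import CustomSearchTool

tool = CustomSearchTool(search_engine="quark")
for hit in tool.search("LangGraph deep research"):  # hypothetical query
    print(hit["title"], hit["url"], hit["publish_date"])
```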
@@ -0,0 +1,526 @@ graph_openai_compatible.py (file name inferred from the imports in the runner script below)
# -*- coding: utf-8 -*-
import asyncio
import json
import os
import time
from typing import Any, Dict, List, Optional

from agentscope_runtime.engine.agents.langgraph_agent import LangGraphAgent
from agentscope_runtime.engine.helpers.helper import simple_call_agent_direct
from configuration import Configuration
from custom_search_tool import CustomSearchTool
from dotenv import load_dotenv
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from llm_prompts import (
    answer_instructions,
    query_writer_instructions,
    reflection_instructions,
    web_searcher_instructions,
)
from llm_utils import call_dashscope, extract_json_from_qwen
from state import (
    OverallState,
    QueryGenerationState,
    ReflectionState,
    WebSearchState,
)
from utils import (
    custom_get_citations,
    custom_resolve_urls,
    get_current_date,
    get_research_topic,
    insert_citation_markers,
)

load_dotenv("../.env")

if os.getenv("DASHSCOPE_API_KEY") is None:
    raise ValueError("DASHSCOPE_API_KEY is not set")


def format_search_results(search_results: List[Dict[str, Any]]) -> str:
    """
    Format the search results as a prompt-ready string.
    :param search_results: list of result dicts from the search tool
    :return: a numbered, newline-separated block of results
    """
    formatted_results = []

    for i, result in enumerate(search_results, 1):
        formatted_result = f"""
Result Number {i}:
Title: {result.get('title', 'N/A')}
Label: {result.get('site_name', 'N/A')}
URL: {result.get('url', 'N/A')}
Snippet: {result.get('snippet', 'N/A')}
publish_date: {result.get('publish_date', 'N/A')}
---
"""
        formatted_results.append(formatted_result)

    return "\n".join(formatted_results)


class WebSearchGraph:
    def __init__(
        self,
        config: RunnableConfig,
        call_llm_func,
        search_tool: CustomSearchTool,
    ):
        self.configurable = Configuration.from_runnable_config(config)
        self.call_llm_func = call_llm_func
        self.search_tool = search_tool
        self.input_tokens = 0
        self.output_tokens = 0
        self.total_tokens = 0
        self.max_retries = 3
        self.retry_delay = 2
        self.current_date = get_current_date()

    def get_chat_completion(self, **args):
        completion = self.call_llm_func(**args)
        self.input_tokens += completion.usage.prompt_tokens
        self.output_tokens += completion.usage.completion_tokens
        self.total_tokens += completion.usage.total_tokens
        return completion.choices[0].message.content

    def generate_query(self, state: OverallState) -> QueryGenerationState:
        """LangGraph node that generates search queries based on the
        User's question.

        Uses Qwen Max to create optimized search queries for web research
        based on the User's question.

        Args:
            state: Current graph state containing the User's question

        Returns:
            Dictionary with state update, including the search_query key
            containing the generated queries
        """
        # Check for a custom initial search query count
        if state.get("initial_search_query_count") is None:
            state[
                "initial_search_query_count"
            ] = self.configurable.num_of_init_q

        # Format the prompt
        formatted_prompt = query_writer_instructions.format(
            current_date=self.current_date,
            research_topic=get_research_topic(state["messages"]),
            number_queries=state["initial_search_query_count"],
        )
        param = {
            "model": self.configurable.query_generator_model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            **self.configurable.query_generator_param,
        }
        for attempt in range(self.max_retries):
            try:
                result = self.get_chat_completion(**param)
                result = extract_json_from_qwen(result)
                result = json.loads(result)
                query = result.get("query")
                if isinstance(query, str):
                    query = [query]
                assert isinstance(query, list)
                break
            except Exception as e:
                print(
                    f"Error occurred when generating search query (attempt"
                    f" {attempt + 1}/{self.max_retries}): {e}.",
                )
                if attempt == self.max_retries - 1:  # Last attempt failed
                    # Fall back to the raw research topic as the only query
                    query = [get_research_topic(state["messages"])]
                    break
                time.sleep(self.retry_delay)
        return {"search_query": query}

    def continue_to_web_research(self, state: QueryGenerationState):
        """LangGraph node that sends the search queries to the
        web research node.

        This is used to spawn n web research nodes, one for each
        search query.
        """
        return [
            Send(
                "web_research",
                {"search_query": search_query, "id": str(idx)},
            )
            for idx, search_query in enumerate(state["search_query"])
        ]
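A standalone sketch (illustrative, not from the commit) of the fan-out payloads `continue_to_web_research` produces; each `Send` routes one payload to its own `web_research` branch.

```python
from langgraph.types import Send

queries = ["apple revenue 2024", "iphone unit sales 2024"]  # hypothetical
sends = [
    Send("web_research", {"search_query": q, "id": str(idx)})
    for idx, q in enumerate(queries)
]
print([(s.node, s.arg["id"]) for s in sends])
# [('web_research', '0'), ('web_research', '1')]
```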
    def web_research(self, state: WebSearchState):
        """LangGraph node that performs web research using the custom
        search tool.

        Executes a web search with the CustomSearchTool and summarizes the
        results with the configured Qwen model.

        Args:
            state: Current graph state containing the search query and
                research loop count

        Returns:
            Dictionary with state update, including sources_gathered,
            search_query, and web_research_result
        """
        search_results = self.search_tool.search(
            state["search_query"],
        )

        search_context = format_search_results(search_results)

        formatted_prompt = (
            web_searcher_instructions.format(
                current_date=self.current_date,
                research_topic=state["search_query"],
            )
            + f"\n\nSearch Result:\n{search_context}"
        )

        param = {
            "model": self.configurable.query_generator_model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            **self.configurable.query_generator_param,
        }

        sources_gathered = []
        for result in search_results:
            url = result.get("url")
            if url:
                sources_gathered.append(
                    {
                        "label": result.get("site_name"),
                        "short_url": url,
                        "value": url,
                    },
                )

        for attempt in range(self.max_retries):
            try:
                result = self.get_chat_completion(**param)
                resolved_urls = custom_resolve_urls(
                    search_results,
                    state["id"],
                )
                citations = custom_get_citations(search_results, resolved_urls)

                modified_text = insert_citation_markers(result, citations)
                return {
                    "sources_gathered": sources_gathered,
                    "search_query": [state["search_query"]],
                    "web_research_result": [modified_text],
                }
            except Exception as e:
                print(
                    f"Error occurred when searching the web for query: "
                    f"`{state['search_query']}` "
                    f"(attempt {attempt + 1}/{self.max_retries}): {e}.",
                )

                summary = (
                    f"{len(search_results)} related results are found "
                    f"about search query '{state['search_query']}'"
                )
                if attempt == self.max_retries - 1:
                    return {
                        "sources_gathered": sources_gathered,
                        "search_query": [state["search_query"]],
                        "web_research_result": [summary],
                    }
                time.sleep(self.retry_delay)
        return None

    def reflection(self, state: OverallState) -> Optional[ReflectionState]:
        """LangGraph node that identifies knowledge gaps and generates
        potential follow-up queries.

        Analyzes the current summary to identify areas for further research
        and generates potential follow-up queries. Uses structured output to
        extract the follow-up query in JSON format.

        Args:
            state: Current graph state containing the running summary and
                research topic

        Returns:
            Dictionary with state update, including the follow_up_queries key
            containing the generated follow-up queries
        """
        state["research_loop_count"] = state.get("research_loop_count", 0) + 1
        reasoning_model = self.configurable.reflection_model

        # Format the prompt
        formatted_prompt = reflection_instructions.format(
            current_date=self.current_date,
            research_topic=get_research_topic(state["messages"]),
            summaries="\n\n---\n\n".join(state["web_research_result"]),
        )
        param = {
            "model": reasoning_model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            **self.configurable.reflection_param,
        }
        for attempt in range(self.max_retries):
            try:
                result = self.get_chat_completion(**param)
                result = extract_json_from_qwen(result)
                result = json.loads(result)
                is_sufficient = result.get("is_sufficient", True)
                knowledge_gap = result.get("knowledge_gap", "")
                follow_up_queries = result.get("follow_up_queries", [])
                assert isinstance(follow_up_queries, list)
                return {
                    "is_sufficient": is_sufficient,
                    "knowledge_gap": knowledge_gap,
                    "follow_up_queries": follow_up_queries,
                    "research_loop_count": state["research_loop_count"],
                    "number_of_ran_queries": len(state["search_query"]),
                }
            except Exception as e:
                print(
                    f"Error occurred when reflecting (attempt {attempt + 1}"
                    f"/{self.max_retries}): {e}.",
                )
                if attempt == self.max_retries - 1:  # Last attempt failed
                    # Treat the research as sufficient so the run can finish
                    return {
                        "is_sufficient": True,
                        "knowledge_gap": "",
                        "follow_up_queries": [],
                        "research_loop_count": state["research_loop_count"],
                        "number_of_ran_queries": len(state["search_query"]),
                    }
                time.sleep(self.retry_delay)
        return None

    def evaluate_research(
        self,
        state: ReflectionState,
        config: RunnableConfig,
    ):
        """LangGraph routing function that determines the next step in the
        research flow.

        Controls the research loop by deciding whether to continue gathering
        information or to finalize the summary based on the configured
        maximum number of research loops.

        Args:
            state: Current graph state containing the research loop count
            config: Configuration for the runnable, including the
                max_research_loops setting

        Returns:
            String literal indicating the next node to visit
            ("web_research" or "finalize_answer")
        """
        configurable = Configuration.from_runnable_config(config)
        max_research_loops = (
            state.get("max_research_loops")
            if state.get("max_research_loops") is not None
            else configurable.max_research_loops
        )
        if (
            state["is_sufficient"]
            or state["research_loop_count"] >= max_research_loops
        ):
            return "finalize_answer"
        else:
            return [
                Send(
                    "web_research",
                    {
                        "search_query": follow_up_query,
                        "id": state["number_of_ran_queries"] + int(idx),
                    },
                )
                for idx, follow_up_query in enumerate(
                    state["follow_up_queries"],
                )
            ]

    def finalize_answer(self, state: OverallState):
        """LangGraph node that finalizes the research summary.

        Prepares the final output by deduplicating and formatting sources,
        then combining them with the running summary to create a
        well-structured research report with proper citations.

        Args:
            state: Current graph state containing the running summary and
                sources gathered

        Returns:
            Dictionary with state update, including the messages key
            containing the final answer with sources
        """
        answer_model = self.configurable.answer_model
        formatted_prompt = answer_instructions.format(
            current_date=self.current_date,
            research_topic=get_research_topic(state["messages"]),
            summaries="\n---\n\n".join(state["web_research_result"]),
        )

        param = {
            "model": answer_model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            **self.configurable.answer_param,
        }
        for attempt in range(self.max_retries):
            try:
                result = self.get_chat_completion(**param)

                # Replace the short urls with the original urls and keep
                # only the sources actually cited in the answer
                unique_sources = []
                for source in state["sources_gathered"]:
                    if source["short_url"] in result:
                        result = result.replace(
                            source["short_url"],
                            source["value"],
                        )
                        unique_sources.append(source)

                return {
                    "messages": [AIMessage(content=result)],
                    "sources_gathered": unique_sources,
                }
            except Exception as e:
                print(
                    f"Error occurred when generating answer (attempt "
                    f"{attempt + 1}/{self.max_retries}): {e}.",
                )
                if attempt == self.max_retries - 1:
                    return {
                        "messages": [
                            AIMessage(
                                content=f"Error occurred"
                                f" when generating answer. {e}",
                            ),
                        ],
                        "sources_gathered": [],
                    }
                time.sleep(self.retry_delay)
        return None

    async def run(self, user_question: str):
        # Create our Agent Graph
        builder = StateGraph(OverallState, config_schema=Configuration)

        # Define the nodes we will cycle between
        builder.add_node("generate_query", self.generate_query)
        builder.add_node("web_research", self.web_research)
        builder.add_node("reflection", self.reflection)
        builder.add_node("finalize_answer", self.finalize_answer)

        # Set the entrypoint as `generate_query`
        # This means that this node is the first one called
        builder.add_edge(START, "generate_query")
        # Add conditional edge to continue with search queries in a
        # parallel branch
        builder.add_conditional_edges(
            "generate_query",
            self.continue_to_web_research,
            ["web_research"],
        )
        # Reflect on the web research
        builder.add_edge("web_research", "reflection")
        # Evaluate the research
        builder.add_conditional_edges(
            "reflection",
            self.evaluate_research,
            ["web_research", "finalize_answer"],
        )
        # Finalize the answer
        builder.add_edge("finalize_answer", END)
        compiled_graph = builder.compile(name="pro-search-agent")

        def human_ai_message_to_dict(obj):
            if isinstance(obj, (HumanMessage, AIMessage)):
                return {
                    "sender": obj.type,
                    "content": obj.content,
                }
            raise TypeError(
                f"Object of type {obj.__class__.__name__} is"
                f" not JSON serializable",
            )

        def state_folder(messages):
            if len(messages) > 0:
                return json.loads(messages[0]["content"])
            else:
                return []

        def state_unfolder(state):
            state_json = json.dumps(state, default=human_ai_message_to_dict)
            return state_json

        langgraph_agent = LangGraphAgent(
            compiled_graph,
            state_folder,
            state_unfolder,
        )

        input_state = {
            "messages": [{"role": "user", "content": user_question}],
            "max_research_loops": self.configurable.max_research_loops,
            "initial_search_query_count": self.configurable.num_of_init_q,
        }
        input_json = json.dumps(input_state)
        all_result = await simple_call_agent_direct(
            langgraph_agent,
            input_json,
        )

        state = json.loads(all_result)
        return state["messages"][-1]["content"]


async def main():
    custom_search_tool = CustomSearchTool(search_engine="quark")

    graph = WebSearchGraph(
        json.loads(Configuration().model_dump_json()),
        call_dashscope,
        custom_search_tool,
    )

    print("Type in your question or q to quit.")

    user_input = input(">").strip()
    while user_input != "q":
        question = user_input
        item = await graph.run(question)
        print(item, end="", flush=True)
        print("\n")

        user_input = input(">").strip()


if __name__ == "__main__":
    asyncio.run(main())
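For orientation, the wiring in `run()` above yields this topology (sketch derived from the `add_edge`/`add_conditional_edges` calls):

```python
# START -> generate_query
# generate_query --(Send, one per query)--> web_research
# web_research -> reflection
# reflection --(evaluate_research)--> web_research (loop) or finalize_answer
# finalize_answer -> END
```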
@@ -0,0 +1,134 @@ llm_prompts.py (file name inferred from the imports in the other files)
# -*- coding: utf-8 -*-


query_writer_instructions = """Your goal is to generate sophisticated and
diverse web search queries.
These queries are intended for an advanced automated web research tool capable
of analyzing complex results,
following links, and synthesizing information.

Instructions:
- Always prefer a single search query; only add another query if the original
question requests multiple aspects or elements and one query is not enough.
- Each query should focus on one specific aspect of the original question.
- Don't produce more than {number_queries} queries.
- Queries should be diverse; if the topic is broad, generate more than 1 query.
- Don't generate multiple similar queries; 1 is enough.
- Query should ensure that the most current information is gathered. The
current date is {current_date}.

Format:
- Format your response as a JSON object with both of these exact keys:
   - "rationale": str, A brief explanation of why these queries are relevant
   to the research topic.
   - "query": list[str], A list of search queries to be used for web research.

Example:

Topic: What revenue grew more last year apple stock or the number of people
buying an iphone
```json
{{
    "rationale": "To answer this comparative growth question accurately,
    we need specific data points on Apple's stock performance and iPhone
    sales metrics. These queries target the precise financial information
    needed: company revenue trends, product-specific unit sales figures,
    and stock price movement over the same fiscal period for
    direct comparison.",
    "query": ["Apple total revenue growth fiscal year 2024", "iPhone unit
    sales growth fiscal year 2024", "Apple stock price growth fiscal
    year 2024"],
}}
```

Context: {research_topic}"""


web_searcher_instructions = """Conduct targeted Google Searches to gather the
most recent, credible
information on "{research_topic}" and synthesize it into a verifiable text
artifact.

Instructions:
- Query should ensure that the most current information is gathered. The
current date is {current_date}.
- Conduct multiple, diverse searches to gather comprehensive information.
- Consolidate key findings while meticulously tracking the source(s) for each
specific piece of information.
- The output should be a well-written summary or report based on your search
findings.
- Only include the information found in the search results, don't make up any
information.

Research Topic:
{research_topic}
"""

reflection_instructions = """You are an expert research assistant analyzing
summaries about "{research_topic}".

Instructions:
- Identify knowledge gaps or areas that need deeper exploration and generate a
follow-up query (1 or multiple).
- If provided summaries are sufficient to answer the user's question, don't
generate a follow-up query.
- If there is a knowledge gap, generate a follow-up query that would help
expand your understanding.
- Focus on technical details, implementation specifics, or emerging trends
that weren't fully covered.

Requirements:
- Ensure the follow-up query is self-contained and includes necessary context
for web search.

Output Format:
- Format your response as a JSON object with these exact keys:
   - "is_sufficient": true or false. Whether the provided summaries are
   sufficient to answer the user's question.
   - "knowledge_gap": str, A description of what information is missing or
   needs clarification. "" if is_sufficient is true.
   - "follow_up_queries": list, A list of follow-up queries to address the
   knowledge gap. [] if is_sufficient is true.

Example:
```json
{{
    "is_sufficient": false,
    "knowledge_gap": "The summary lacks information about performance metrics
    and benchmarks",
    "follow_up_queries": ["What are typical performance benchmarks and metrics
    used to evaluate [specific technology]?"]
}}
```

Reflect carefully on the Summaries to identify knowledge gaps and produce a
follow-up query.
Then, produce your output following this JSON format:

Summaries:
{summaries}
"""

answer_instructions = """Generate a high-quality answer to the user's question
based on the provided summaries.

Instructions:
- The current date is {current_date}.
- You are the final step of a multi-step research process, don't mention that
you are the final step.
- You have access to all the information gathered from the previous steps.
- You have access to the user's question.
- Generate a high-quality answer to the user's question based on the provided
summaries and the user's question.
- Include the sources you used from the Summaries in the answer correctly,
use markdown format. THIS IS A MUST.

User Context:
- {research_topic}

Summaries:
{summaries}"""
@@ -0,0 +1,177 @@ llm_utils.py (file name inferred from the imports in the other files)
# -*- coding: utf-8 -*-
import json
import os
import re
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, Iterator, List, Optional

from openai import OpenAI
from openai.types.chat.chat_completion import (
    ChatCompletion,
    ChatCompletionMessage,
    Choice,
)
from openai.types.chat.chat_completion_message_tool_call import (
    ChatCompletionMessageToolCall,
    Function,
)


def extract_json_from_qwen(qwen_result: str) -> str:
    """Return the last ```json fenced block in the reply, or "" if none."""
    json_str = ""
    pattern = r"```json(.*?)```"

    json_code_snippets = re.findall(pattern, qwen_result, re.DOTALL)

    if len(json_code_snippets) > 0:
        json_str = json_code_snippets[-1].strip()

    return json_str
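A quick check of the helper above (hypothetical reply text):

```python
reply = 'Sure. ```json\n{"query": ["apple revenue 2024"]}\n``` Done.'
print(extract_json_from_qwen(reply))
# {"query": ["apple revenue 2024"]}
```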
def call_dashscope(**args: Any) -> ChatCompletion:
    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    completion = client.chat.completions.create(
        **args,
    )
    stream = args.get("stream", False)
    if stream:
        try:
            completion = postprocess_completion(completion)
            return completion
        except Exception as e:
            print(
                f"Error occurred when postprocess_completion on "
                f"'stream=True'. {e}",
            )
            default_message = ChatCompletionMessage(
                role="assistant",
                content="Error in calling LLM",  # default content
            )
            default_choice = Choice(
                finish_reason="stop",
                index=0,
                logprobs=None,
                message=default_message,
            )
            default_chat_completion = ChatCompletion(
                id="chatcmpl-1234567890",
                choices=[default_choice],
                created=int(datetime.now().timestamp()),
                model=args["model"],
                object="chat.completion",
                service_tier="default",
                system_fingerprint=None,
                usage=None,
            )
            return default_chat_completion
    return completion


def merge_fields(target: Dict[str, Any], source: Dict[str, Any]) -> None:
    """Merge a streaming delta into the accumulated message: concatenate
    string fields and recurse into nested dicts."""
    for key, value in source.items():
        if isinstance(value, str):
            target[key] = target.get(key, "") + value
        elif value is not None and isinstance(value, dict):
            merge_fields(target[key], value)
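An example of the merge rule (hypothetical deltas): strings concatenate, nested dicts merge recursively.

```python
acc = {"content": "Hel", "function": {"name": "se"}}
merge_fields(acc, {"content": "lo", "function": {"name": "arch"}})
print(acc)
# {'content': 'Hello', 'function': {'name': 'search'}}
```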
def merge_chunk(final_response: Dict[str, Any], delta: Dict[str, Any]) -> None:
    delta.pop("role", None)
    merge_fields(final_response, delta)

    tool_calls = delta.get("tool_calls")
    if tool_calls and len(tool_calls) > 0:
        index = int(tool_calls[0].pop("index"))  # Convert index to integer
        if "tool_calls" not in final_response:
            final_response["tool_calls"] = {}
        final_response["tool_calls"][index] = final_response["tool_calls"].get(
            index,
            {},
        )
        final_response["tool_calls"][index].pop("type", None)
        merge_fields(final_response["tool_calls"][index], tool_calls[0])


def postprocess_completion(completion: Iterator) -> ChatCompletion:
    """Fold a stream of chat-completion chunks into one ChatCompletion."""
    message: Dict[str, Any] = {
        "content": "",
        "role": "assistant",
        "function_call": None,
        "tool_calls": defaultdict(
            lambda: {
                "function": {"arguments": "", "name": ""},
                "id": "",
                "type": "",
            },
        ),
        "reasoning_content": "",
        "refusal": "",
    }
    last_chunk: Optional[Any] = None

    for chunk in completion:
        try:
            delta = json.loads(chunk.choices[0].delta.json())
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from chunk: {e}")
            continue
        delta.pop("role", None)
        merge_chunk(message, delta)
        finish_reason = chunk.choices[0].finish_reason
        logprobs = chunk.choices[0].logprobs
        last_chunk = chunk

    # Declare the type explicitly
    tool_calls_list: List[Dict[str, Any]] = list(
        message.get("tool_calls", {}).values(),
    )
    message["tool_calls"] = tool_calls_list

    tool_calls = None
    if message["tool_calls"]:
        tool_calls = []
        for tool_call in message["tool_calls"]:  # each element is a dict
            function = Function(
                arguments=tool_call["function"]["arguments"],
                name=tool_call["function"]["name"],
            )
            tool_call_object = ChatCompletionMessageToolCall(
                id=tool_call["id"],
                function=function,
                type=tool_call["type"],
            )
            tool_calls.append(tool_call_object)

    chat_message = ChatCompletionMessage(
        content=message["content"],
        role=message["role"],
        function_call=message["function_call"],
        tool_calls=tool_calls,
        reasoning_content=message["reasoning_content"],
        refusal=message["refusal"],
    )
    choices = [
        Choice(
            finish_reason=finish_reason,
            index=0,
            message=chat_message,
            logprobs=logprobs,
        ),
    ]

    completion = ChatCompletion(
        id=last_chunk.id,
        choices=choices,
        created=last_chunk.created,
        model=last_chunk.model,
        object="chat.completion",
        service_tier=last_chunk.service_tier,
        system_fingerprint=last_chunk.system_fingerprint,
        usage=last_chunk.usage,
    )
    return completion
@@ -0,0 +1,29 @@ (interactive runner script; file name not shown in this view)
# -*- coding: utf-8 -*-
import asyncio
import json

from qwen_langgraph_search.src.configuration import Configuration
from qwen_langgraph_search.src.custom_search_tool import CustomSearchTool
from qwen_langgraph_search.src.graph_openai_compatible import WebSearchGraph
from qwen_langgraph_search.src.llm_utils import call_dashscope

if __name__ == "__main__":
    custom_search_tool = CustomSearchTool(search_engine="quark")
    graph = WebSearchGraph(
        json.loads(Configuration().model_dump_json()),
        call_dashscope,
        custom_search_tool,
    )

    user_input = input("Type in your question or press q to quit\n")
    while user_input != "q":
        question = user_input

        try:
            res = asyncio.run(graph.run(question))
            print(res)
        except Exception as e:
            print(f"An error occurred: {e}")

        user_input = input("Type in your question or press q to quit\n")
@@ -0,0 +1,47 @@ state.py (file name inferred from the imports in the other files)
# -*- coding: utf-8 -*-
from __future__ import annotations

import operator
from dataclasses import dataclass, field
from typing import Optional, TypedDict

from langgraph.graph import add_messages
from typing_extensions import Annotated


class OverallState(TypedDict):
    messages: Annotated[list, add_messages]
    search_query: Annotated[list, operator.add]
    web_research_result: Annotated[list, operator.add]
    sources_gathered: Annotated[list, operator.add]
    initial_search_query_count: int
    max_research_loops: int
    research_loop_count: int
    reasoning_model: str


class ReflectionState(TypedDict):
    is_sufficient: bool
    knowledge_gap: str
    follow_up_queries: Annotated[list, operator.add]
    research_loop_count: int
    number_of_ran_queries: int


class Query(TypedDict):
    query: str
    rationale: str


class QueryGenerationState(TypedDict):
    search_query: list[Query]


class WebSearchState(TypedDict):
    search_query: str
    id: str


@dataclass(kw_only=True)
class SearchStateOutput:
    running_summary: Optional[str] = field(default=None)  # Final report
@@ -0,0 +1,29 @@ (schema definitions; file name not shown in this view)
# -*- coding: utf-8 -*-
from typing import List

from pydantic import BaseModel, Field


class SearchQueryList(BaseModel):
    query: List[str] = Field(
        description="A list of search queries to be used for web research.",
    )
    rationale: str = Field(
        description="A brief explanation of why these queries are relevant "
        "to the research topic.",
    )


class Reflection(BaseModel):
    is_sufficient: bool = Field(
        description="Whether the provided summaries are sufficient to answer "
        "the user's question.",
    )
    knowledge_gap: str = Field(
        description="A description of what information is missing or needs "
        "clarification.",
    )
    follow_up_queries: List[str] = Field(
        description="A list of follow-up queries to address the knowledge "
        "gap.",
    )
@@ -0,0 +1,129 @@ utils.py (file name inferred from the imports in the other files)
# -*- coding: utf-8 -*-
import time
from datetime import datetime
from typing import Any, Dict, List

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage


def get_current_date() -> str:
    return datetime.now().strftime("%B %d, %Y")


def format_time(timestamp_param: str, format_str: str = "%Y-%m-%d") -> str:
    """Convert an epoch-seconds string to a date string; return "" for any
    non-numeric or out-of-range input."""
    if not timestamp_param or not timestamp_param.isnumeric():
        return ""

    try:
        timestamp = int(timestamp_param)
        return time.strftime(format_str, time.localtime(timestamp))
    except (ValueError, OverflowError, OSError):
        return ""
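Example behavior of `format_time` (the exact date depends on the local timezone):

```python
print(format_time("1700000000"))  # e.g. 2023-11-14
print(format_time("n/a"))         # ""
```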
def get_research_topic(messages: List[AnyMessage]) -> str:
    """
    Get the research topic from the messages.
    """
    # Check if the request has a history and combine the messages
    # into a single string
    if len(messages) == 1:
        research_topic = messages[-1].content
    else:
        research_topic = ""
        for message in messages:
            if isinstance(message, HumanMessage):
                research_topic += f"User: {message.content}\n"
            elif isinstance(message, AIMessage):
                research_topic += f"Assistant: {message.content}\n"
    return research_topic
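An example (hypothetical history): multi-turn conversations are flattened into a User:/Assistant: transcript.

```python
from langchain_core.messages import AIMessage, HumanMessage

msgs = [
    HumanMessage(content="hi"),
    AIMessage(content="hello"),
    HumanMessage(content="apple revenue?"),
]
print(get_research_topic(msgs))
# User: hi
# Assistant: hello
# User: apple revenue?
```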
def insert_citation_markers(text: str, citations_list: List[Dict]) -> str:
    """
    Inserts citation markers into a text string based on start and end
    indices.

    Args:
        text (str): The original text string.
        citations_list (list): A list of dictionaries, where each dictionary
                               contains 'start_index', 'end_index', and
                               'segments' (the markers to insert).
                               Indices are assumed to be for the original
                               text.

    Returns:
        str: The text with citation markers inserted.
    """
    # Sort citations by end_index in descending order.
    # If end_index is the same, secondary sort by start_index descending.
    # This ensures that insertions at the end of the string don't affect
    # the indices of earlier parts of the string that still
    # need to be processed.
    sorted_citations = sorted(
        citations_list,
        key=lambda c: (c["end_index"], c["start_index"]),
        reverse=True,
    )

    modified_text = text
    for citation_info in sorted_citations:
        # These indices refer to positions in the *original* text,
        # but since we iterate from the end, they remain valid for insertion
        # relative to the parts of the string already processed.
        end_idx = citation_info["end_index"]
        marker_to_insert = ""
        for segment in citation_info["segments"]:
            marker_to_insert += (
                f" [{segment['label']}]({segment['short_url']})"
            )
        # Insert the citation marker at the original end_idx position
        modified_text = (
            modified_text[:end_idx]
            + marker_to_insert
            + modified_text[end_idx:]
        )

    return modified_text
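A worked example (hypothetical citations): processing the largest `end_index` first keeps the earlier indices valid.

```python
text = "Apples grew. Stocks rose."
citations = [
    {"start_index": 0, "end_index": 12,
     "segments": [{"label": "a", "short_url": "u1"}]},
    {"start_index": 13, "end_index": 25,
     "segments": [{"label": "b", "short_url": "u2"}]},
]
print(insert_citation_markers(text, citations))
# Apples grew. [a](u1) Stocks rose. [b](u2)
```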
def custom_resolve_urls(
    search_results: List[Dict[str, Any]],
    uid: str,
) -> Dict[str, str]:
    """Map each result URL to a stable short URL keyed by its position."""
    prefix = "https://search-result.local/id/"
    resolved_map = {}

    for idx, result in enumerate(search_results):
        url = result.get("url", "")
        if url and url not in resolved_map:
            resolved_map[url] = f"{prefix}{uid}-{idx}"

    return resolved_map
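Example output (hypothetical URLs): short URLs are keyed by the branch id and the result's position.

```python
results = [{"url": "https://a.example/x"}, {"url": "https://b.example/y"}]
print(custom_resolve_urls(results, uid="3"))
# {'https://a.example/x': 'https://search-result.local/id/3-0',
#  'https://b.example/y': 'https://search-result.local/id/3-1'}
```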
def custom_get_citations(
    search_results: List[Dict[str, Any]],
    resolved_urls_map: Dict[str, str],
) -> List[Dict[str, Any]]:
    citations = []

    for idx, result in enumerate(search_results):
        url = result.get("url", "")
        title = result.get("title", f"Search result {idx + 1}")

        if url:
            citation = {
                # Simplified: a real application could locate the cited
                # span more precisely
                "start_index": 0,
                "end_index": len(title),
                "segments": [
                    {
                        "label": title[:50] + "..."
                        if len(title) > 50
                        else title,
                        "short_url": resolved_urls_map.get(url, url),
                        "value": url,
                    },
                ],
            }
            citations.append(citation)

    return citations