Files
evotraders/tuner/email_search/main.py
2026-01-19 12:25:13 +08:00

380 lines
12 KiB
Python

# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict
from _email_search_agent import EmailSearchAgent
from _utils import ( # pylint: disable=E0611
AnswerModel,
FinalRubric,
QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
TunerModelConfig,
DatasetConfig,
JudgeOutput,
WorkflowOutput,
AlgorithmConfig,
tune,
)
from agentscope.model import ChatModelBase
# System prompt template for the email search agent. It is filled in with
# ``str.format`` using the ``max_turns``, ``inbox_address`` and ``query_date``
# placeholders before being handed to the agent as its ``sys_prompt``.
# NOTE(review): "seach" in the text below looks like a typo for "search"; it
# is left untouched here because the string is part of the runtime prompt.
SYSTEM_PROMPT = """You are an email search agent. You are given a user query
and a list of tools you can use to search the user's email. Use the tools to
search the user's emails and find the answer to the user's query. You may take
up to {max_turns} turns to find the answer, so if your first seach doesn't
find the answer, you can try with different keywords.
Always describe what you see and plan your next steps clearly. When taking
actions, explain what you're doing and why. When the answer to the task is
found, call `generate_response` to finish the process. Only call
`generate_response` when answer is found. You should not respond any next steps
in `generate_response`. Complete all steps and then call `generate_response`.
User's email address is {inbox_address}
Today's date is {query_date}
"""
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:  # noqa: PLR0915
    """Run one Email Search rollout and attach the data the judge needs.

    Args:
        task (Dict): The task to be solved. It is validated against
            ``QueryModel``; the user question is read from ``question``
            (falling back to ``task_desc``), and ``workflow_args.max_turns``
            (default 10) caps the agent's iterations.
        model (ChatModelBase): The language model driving the agent.
        auxiliary_models (Dict[str, ChatModelBase]): Additional models for
            LLM-as-a-Judge usage. They are not called in this function; it
            only checks that at least one is present.

    Returns:
        WorkflowOutput: Wraps the agent's reply. The reply's ``metadata`` is
        extended with ``answer_and_sources``, ``query``,
        ``message_id_list``, ``ever_read_message_ids`` and ``actual_turns``.

    Raises:
        AssertionError: If ``auxiliary_models`` is empty.
    """
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"

    # Parse task data
    query = QueryModel.model_validate(task)
    question = task.get("question", task.get("task_desc", ""))

    # ``workflow_args`` may be missing or explicitly null in a task record.
    workflow_args = task.get("workflow_args") or {}
    max_turns = int(workflow_args.get("max_turns", 10))

    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=SYSTEM_PROMPT.format(
            max_turns=max_turns,
            inbox_address=query.inbox_address,
            query_date=query.query_date,
        ),
        model=model,
        formatter=OpenAIChatFormatter(),
        max_iters=max_turns,
    )

    # Reset agent state for a new rollout
    await agent.reset()

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    # Take a *copy* of the structured output. Using ``response.metadata``
    # itself here would make the ``update`` below store the dict inside
    # itself (a circular reference that breaks serialisation).
    answer_and_sources = dict(response.metadata or {})
    if not answer_and_sources:
        # Fallback: use the plain text of the reply as the answer
        answer_and_sources = {
            "answer": response.get_text_content() or "",
            "sources": [],
        }

    # Estimate the number of turns from the memory length, never below 1.
    memory_size = len(agent.memory.content)
    actual_turns = max(1, (memory_size - 1) // 2)

    # Pass agent state to the judge function through the response metadata
    if response.metadata is None:
        response.metadata = {}
    response.metadata.update(
        {
            "answer_and_sources": answer_and_sources,
            "query": query.model_dump(),
            "message_id_list": agent.message_id_list,
            "ever_read_message_ids": agent.ever_read_message_ids,
            "actual_turns": actual_turns,
        },
    )
    return WorkflowOutput(
        response=response,
    )
def _calculate_partial_rewards(rubric: FinalRubric) -> float:
    """Return 0.1 of partial credit for every milestone the rubric records.

    The milestones are: the right email was found, the right email was read,
    and the right email was cited as a source.
    """
    milestones = (
        rubric.ever_found_right_email,
        rubric.ever_read_right_email,
        rubric.sources_correct,
    )
    credit = 0.0
    # Accumulate one step at a time, in the same order as the flags above.
    for reached in milestones:
        if reached:
            credit += 0.1
    return credit
def _calculate_correct_answer_reward(
    rubric: FinalRubric,
    max_turns: int,
) -> float:
    """Calculate the reward for a correct answer.

    Starts from 1.0 and adds:
    * 0.3 when the cited sources include the expected email,
    * ``0.1 / num_sources`` when at least one source was cited (fewer
      sources earn more),
    * ``0.1 * (1 - num_turns / max_turns)`` as an efficiency bonus.

    Args:
        rubric (FinalRubric): The populated rubric for the rollout.
        max_turns (int): The turn budget of the rollout.

    Returns:
        float: The reward for the correct answer.
    """
    reward = 1.0
    if rubric.sources_correct:
        reward += 0.3
    if rubric.num_sources > 0:
        reward += 0.1 / rubric.num_sources
    # A non-positive budget makes the ratio undefined; skip the efficiency
    # bonus instead of raising ZeroDivisionError.
    if max_turns > 0:
        reward += 0.1 * (1 - rubric.num_turns / max_turns)
    return reward
def _initialize_rubric(
    answer: str,
    sources: list[str],
    actual_turns: int,
    query: QueryModel,
    message_id_list: list[str],
    ever_read_message_ids: list[str],
) -> FinalRubric:
    """Build a ``FinalRubric`` and fill in the facts known before judging.

    Only the first entry of ``query.message_ids`` is treated as the expected
    email; when the query lists none, the three email-related flags are left
    at whatever ``FinalRubric()`` initialises them to.
    """
    rubric = FinalRubric()

    # Any non-empty answer counts as an attempt (this includes the literal
    # "I don't know" string, which is additionally flagged below).
    rubric.attempted_answer = not (answer is None or answer == "")
    rubric.returned_i_dont_know = answer == "I don't know"
    rubric.num_sources = len(sources)
    rubric.num_turns = actual_turns

    if len(query.message_ids) == 0:
        return rubric

    expected_id = query.message_ids[0]
    rubric.ever_found_right_email = expected_id in message_id_list
    rubric.ever_read_right_email = expected_id in ever_read_message_ids
    rubric.sources_correct = expected_id in sources
    return rubric
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    """A judge function to calculate reward based on agent's response.

    The reward is the sum of two components, ``accuracy`` and ``format``:

    * no usable answer in the metadata: accuracy 0.0, format -1.0;
    * answer attempted but judged incorrect: accuracy -1.0, format =
      partial rewards;
    * "I don't know" / ran out of turns: accuracy 0.0, format = partial
      rewards;
    * answer judged correct: accuracy 1.0, format = the value of
      ``_calculate_correct_answer_reward``;
    * anything else: both 0.0.

    Args:
        task (Dict): The task information for the corresponding workflow.
        response (Msg): The response generated by the corresponding workflow.
            Its ``metadata`` is read for ``answer_and_sources``, ``query``,
            ``message_id_list``, ``ever_read_message_ids`` and
            ``actual_turns``.
        auxiliary_models (Dict[str, ChatModelBase]):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. The ``"judge"`` entry is preferred;
            otherwise the first model in the dictionary is used.

    Returns:
        JudgeOutput: The reward value assigned by the judge function, with
        the two components and ``actual_turns`` reported as metrics.
    """
    # Extract metadata from response
    metadata = response.metadata or {}
    answer_and_sources = metadata.get("answer_and_sources", {})
    message_id_list = metadata.get("message_id_list", [])
    ever_read_message_ids = metadata.get("ever_read_message_ids", [])
    actual_turns = metadata.get("actual_turns", 0)

    # Fall back to the raw task when the workflow recorded no query.
    query = QueryModel.model_validate(metadata.get("query") or task)

    # ``workflow_args`` may be missing or explicitly null in a task record.
    workflow_args = task.get("workflow_args") or {}
    max_turns = int(workflow_args.get("max_turns", 10))

    # Extract answer and sources. A malformed payload and a missing answer
    # are both treated as a format failure.
    if isinstance(answer_and_sources, dict):
        answer = answer_and_sources.get("answer")
        # ``"sources": None`` would otherwise crash ``len(sources)`` later.
        sources = answer_and_sources.get("sources") or []
    else:
        answer, sources = None, []
    if answer is None:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )

    # Initialize rubric
    rubric = _initialize_rubric(
        answer,
        sources,
        actual_turns,
        query,
        message_id_list,
        ever_read_message_ids,
    )

    # Judge correctness using LLM-as-a-Judge. This is deliberately
    # best-effort: any failure (including a missing judge model) is logged
    # and scored as an incorrect answer rather than aborting the rollout.
    try:
        judge_model = None
        if auxiliary_models:
            judge_model = auxiliary_models.get("judge") or next(
                iter(auxiliary_models.values()),
            )
        rubric.answer_correct = await judge_correctness(
            answer,
            query,
            judge_model,
        )
    except Exception as e:
        logger.error("Error judging correctness: %s", e)
        rubric.answer_correct = False

    # Calculate rewards
    partial_rewards = _calculate_partial_rewards(rubric)
    if rubric.attempted_answer and not rubric.answer_correct:
        result = {"accuracy": -1.0, "format": partial_rewards}
    elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        # NOTE(review): ``ran_out_of_turns`` is not assigned anywhere in this
        # function; presumably it keeps the FinalRubric default - confirm.
        result = {"accuracy": 0.0, "format": partial_rewards}
    elif rubric.answer_correct:
        reward = _calculate_correct_answer_reward(rubric, max_turns)
        result = {"accuracy": 1.0, "format": reward}
    else:
        result = {"accuracy": 0.0, "format": 0.0}

    # ``actual_turns`` is reported as a metric but is not part of the reward.
    metrics = {**result, "actual_turns": actual_turns}
    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
# LLM-as-a-judge
def _parse_accept_flag(text: str) -> bool:
    """Return the verdict given by the first standalone true/false word."""
    # Replace every non-letter with a space so punctuation and markdown
    # ("**true**", "accept=false.") do not glue onto the verdict word.
    words = "".join(ch if ch.isalpha() else " " for ch in text.lower()).split()
    for word in words:
        if word in ("true", "false"):
            return word == "true"
    # No verdict at all: reject.
    return False


async def judge_correctness(
    answer: str,
    query: QueryModel,
    judge: ChatModelBase,
) -> bool:
    """Use an LLM to decide whether *answer* matches *query.answer*.

    Args:
        answer (str): The answer produced by the agent.
        query (QueryModel): Supplies the question and the reference answer.
        judge (ChatModelBase): The chat model used as the judge. It is
            called with a system message and a user message.

    Returns:
        bool: The *accept* flag used for scoring. The first standalone
        ``true`` / ``false`` word in the judge's text decides. A reply such
        as "false - this is not the true date" or "untrue" is therefore
        rejected, not accepted merely for containing the substring "true".
        A reply with no verdict is rejected.
    """
    system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).
Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.
Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.
Return your judgement **accept** from **true** and **false**. Do not return
any other text or formatting.
"""
    prompt = (
        f"Question: {query.question}\n"
        f"Reference answer: {query.answer}\n"
        f"AI answer: {answer}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    chat_response = await judge(messages)

    # Keep only the text blocks of the response and concatenate them.
    result = "".join(
        str(block.get("text", ""))
        for block in chat_response.content
        if isinstance(block, dict) and block.get("type") == "text"
    )
    logger.info("LLM judge response: %s", result)
    return _parse_accept_flag(result)
# End of LLM-as-a-judge
if __name__ == "__main__":
    # The YAML config is looked up next to this script.
    script_dir = os.path.dirname(__file__)
    yaml_path = os.path.join(script_dir, "config.yaml")

    # Training split of the email dataset; the path must be edited before
    # running (presumably a placeholder - confirm).
    train_data = DatasetConfig(
        path="/path/to/enron_emails_dataset",
        split="train",
    )

    # Model being tuned.
    policy_model = TunerModelConfig(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        max_model_len=20480,
        max_tokens=4096,
        inference_engine_num=4,
        reasoning_parser=None,
    )

    # Model used for LLM-as-a-Judge; registered under the "judge" key.
    judge_model = TunerModelConfig(
        model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
        max_model_len=2500,
        max_tokens=2048,
        inference_engine_num=1,
        tensor_parallel_size=2,
        tool_call_parser=None,
        reasoning_parser=None,
    )

    grpo = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        batch_size=64,
        learning_rate=1e-6,
    )

    tune(
        workflow_func=run_email_search_agent,
        judge_func=email_search_judge,
        train_dataset=train_data,
        model=policy_model,
        auxiliary_models={"judge": judge_model},
        algorithm=grpo,
        config_path=yaml_path,
    )