# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict

from _email_search_agent import EmailSearchAgent
from _utils import (  # pylint: disable=E0611
    AnswerModel,
    FinalRubric,
    QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
    TunerModelConfig,
    DatasetConfig,
    JudgeOutput,
    WorkflowOutput,
    AlgorithmConfig,
    tune,
)
from agentscope.model import ChatModelBase


SYSTEM_PROMPT = """You are an email search agent. You are given a user query
|
|
and a list of tools you can use to search the user's email. Use the tools to
|
|
search the user's emails and find the answer to the user's query. You may take
|
|
up to {max_turns} turns to find the answer, so if your first seach doesn't
|
|
find the answer, you can try with different keywords.
|
|
|
|
Always describe what you see and plan your next steps clearly. When taking
|
|
actions, explain what you're doing and why. When the answer to the task is
|
|
found, call `generate_response` to finish the process. Only call
|
|
`generate_response` when answer is found. You should not respond any next steps
|
|
in `generate_response`. Complete all steps and then call `generate_response`.
|
|
|
|
User's email address is {inbox_address}
|
|
|
|
Today's date is {query_date}
|
|
|
|
"""
|
|
|
|
|
|
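# Note on task records: the workflow below validates each task against
# QueryModel from _utils, so a record is expected to carry fields such as
# question, answer, inbox_address, query_date, and message_ids, plus an
# optional workflow_args dict (e.g. {"max_turns": 10}). A hypothetical record
# (illustrative values only) might look like:
#
# {
#     "question": "Who sent the shipping schedule for the Houston office?",
#     "answer": "...",
#     "inbox_address": "user@example.com",
#     "query_date": "2001-06-01",
#     "message_ids": ["<gold-email-message-id>"],
#     "workflow_args": {"max_turns": 10},
# }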
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:  # noqa: PLR0915
    """A workflow function using the Email Search agent to solve tasks.

    Args:
        task (Dict): The task to be solved.
            Should contain fields from QueryModel.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Additional chat models;
            at least one is required for LLM-as-a-Judge scoring.

    Returns:
        WorkflowOutput: The output containing the agent's response.
    """
    assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"

    # Parse task data
    query = QueryModel.model_validate(task)
    question = task.get("question", task.get("task_desc", ""))

    # Get workflow arguments with defaults
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))

    # Format system prompt
    system_prompt = SYSTEM_PROMPT.format(
        max_turns=max_turns,
        inbox_address=query.inbox_address,
        query_date=query.query_date,
    )

    # Create EmailSearchAgent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        formatter=OpenAIChatFormatter(),
        max_iters=max_turns,
    )

    # Reset agent state for a new rollout
    await agent.reset()

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    # Extract answer and sources from response metadata
    answer_and_sources = response.metadata or {}
    if not answer_and_sources:
        # Fallback: try to parse from content
        answer_and_sources = {
            "answer": response.get_text_content() or "",
            "sources": [],
        }

    # Store agent state for judge function
    # We'll pass this through the response metadata
    response_metadata = {
        "answer_and_sources": answer_and_sources,
        "query": query.model_dump(),
        "message_id_list": agent.message_id_list,
        "ever_read_message_ids": agent.ever_read_message_ids,
        # Estimate actual_turns from memory length
        "actual_turns": (
            max(1, (len(agent.memory.content) - 1) // 2)
            if len(agent.memory.content) > 1
            else 1
        ),
    }

    # Update response metadata
    if response.metadata is None:
        response.metadata = {}
    response.metadata.update(response_metadata)

    return WorkflowOutput(
        response=response,
    )


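# Reward shaping note: the helpers below are combined by email_search_judge
# into reward = accuracy + format. A worked example (illustrative numbers): a
# correct answer that cites exactly the right source email and uses 4 of 10
# allowed turns scores format = 1.0 + 0.3 + 0.1 + 0.1 * (1 - 4 / 10) = 1.46,
# plus accuracy = 1.0, for a total reward of 2.46. A wrong attempted answer
# scores accuracy = -1.0 plus at most 0.3 of partial format credit.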
def _calculate_partial_rewards(rubric: FinalRubric) -> float:
    """Calculate partial rewards based on rubric."""
    partial_rewards = 0.0
    partial_rewards += 0.1 if rubric.ever_found_right_email else 0
    partial_rewards += 0.1 if rubric.ever_read_right_email else 0
    partial_rewards += 0.1 if rubric.sources_correct else 0
    return partial_rewards


def _calculate_correct_answer_reward(
    rubric: FinalRubric,
    max_turns: int,
) -> float:
    """Calculate reward for correct answers."""
    reward = 1.0
    reward += 0.3 if rubric.sources_correct else 0
    reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0
    reward += 0.1 * (1 - rubric.num_turns / max_turns)
    return reward


def _initialize_rubric(
    answer: str,
    sources: list[str],
    actual_turns: int,
    query: QueryModel,
    message_id_list: list[str],
    ever_read_message_ids: list[str],
) -> FinalRubric:
    """Initialize and populate rubric with basic information."""
    rubric = FinalRubric()
    rubric.attempted_answer = answer is not None and answer != ""
    rubric.returned_i_dont_know = answer == "I don't know"
    rubric.num_sources = len(sources)
    rubric.num_turns = actual_turns

    if len(query.message_ids) > 0:
        rubric.ever_found_right_email = query.message_ids[0] in message_id_list
        rubric.ever_read_right_email = (
            query.message_ids[0] in ever_read_message_ids
        )
        rubric.sources_correct = query.message_ids[0] in sources
    return rubric


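# Note: only the first entry of query.message_ids is treated as the gold email
# when populating the search/read/source flags in the rubric above.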
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    """A judge function to calculate reward based on agent's response.

    Args:
        task (Dict): The task information for the corresponding workflow.
        response (Msg): The response generated by the corresponding workflow.
        auxiliary_models (Dict[str, ChatModelBase]):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. The keys are model names, and the values
            are the corresponding ChatModelBase instances.

    Returns:
        JudgeOutput: The reward and metrics assigned by the judge function.
    """
    # Extract metadata from response
    metadata = response.metadata or {}
    answer_and_sources = metadata.get("answer_and_sources", {})
    query_dict = metadata.get("query", {})
    message_id_list = metadata.get("message_id_list", [])
    ever_read_message_ids = metadata.get("ever_read_message_ids", [])
    actual_turns = metadata.get("actual_turns", 0)

    # Parse query model
    if not query_dict:
        query_dict = task
    query = QueryModel.model_validate(query_dict)

    # Get arguments
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))

    # Extract answer and sources
    try:
        answer = answer_and_sources.get("answer", None)
        sources = answer_and_sources.get("sources", [])
    except Exception:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )

    if answer is None:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )

    # Initialize rubric
    rubric = _initialize_rubric(
        answer,
        sources,
        actual_turns,
        query,
        message_id_list,
        ever_read_message_ids,
    )

    # Judge correctness using LLM-as-a-Judge
    try:
        judge_model = (
            auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
            if auxiliary_models
            else None
        )
        judge_response = await judge_correctness(
            answer,
            query,
            judge_model,
        )
        rubric.answer_correct = judge_response
    except Exception as e:
        logger.error("Error judging correctness: %s", e)
        rubric.answer_correct = False

    # Calculate rewards
    partial_rewards = _calculate_partial_rewards(rubric)

    if rubric.attempted_answer and not rubric.answer_correct:
        result = {"accuracy": -1.0, "format": partial_rewards}
    elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        result = {"accuracy": 0.0, "format": partial_rewards}
    elif rubric.answer_correct:
        reward = _calculate_correct_answer_reward(rubric, max_turns)
        result = {"accuracy": 1.0, "format": reward}
    else:
        result = {"accuracy": 0.0, "format": 0.0}

    metrics = result.copy()
    metrics.update({"actual_turns": actual_turns})

    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )


# LLM-as-a-judge


async def judge_correctness(
    answer: str,
    query: QueryModel,
    judge: ChatModelBase,
) -> bool:
    """Use an LLM to decide whether *answer* matches *query.answer*.

    Returns a boolean *accept* flag used for scoring.
    """

    system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).

Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.

Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.

Return your judgement **accept** as either **true** or **false**. Do not
return any other text or formatting.
"""
    prompt = (
        f"Question: {query.question}\n"
        f"Reference answer: {query.answer}\n"
        f"AI answer: {answer}"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    chat_response = await judge(messages)

    # Extract text content from ChatResponse
    result_parts = []
    for block in chat_response.content:
        if isinstance(block, dict) and block.get("type") == "text":
            result_parts.append(str(block.get("text", "")))
    result = "".join(result_parts)
    logger.info("LLM judge response: %s", result)

    return "true" in result.lower()


# End of LLM-as-a-judge


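# The entry point below wires the workflow and judge into `tune`. The dataset
# path is a placeholder for a prepared Enron email dataset, and config.yaml
# (expected next to this file) is assumed to carry the remaining Trinity-RFT
# settings; the model paths, engine counts, and batch sizes are example values
# to adapt to the available hardware.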
if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/enron_emails_dataset",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        max_model_len=20480,
        max_tokens=4096,
        inference_engine_num=4,
        reasoning_parser=None,
    )
    aux_models = {
        "judge": TunerModelConfig(
            model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
            max_model_len=2500,
            max_tokens=2048,
            inference_engine_num=1,
            tensor_parallel_size=2,
            tool_call_parser=None,
            reasoning_parser=None,
        ),
    }
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        batch_size=64,
        learning_rate=1e-6,
    )
    tune(
        workflow_func=run_email_search_agent,
        judge_func=email_search_judge,
        train_dataset=dataset,
        model=tuner_model,
        auxiliary_models=aux_models,
        algorithm=algorithm,
        config_path=config_path,
    )