# -*- coding: utf-8 -*-
|
|
"""Example of training a ReAct math-agent with configurable task selector."""
|
|
from typing import Dict
|
|
import argparse
|
|
|
|
from agentscope.tuner import (
|
|
tune,
|
|
WorkflowOutput,
|
|
JudgeOutput,
|
|
TunerModelConfig,
|
|
AlgorithmConfig,
|
|
)
|
|
from agentscope.agent import ReActAgent
|
|
from agentscope.model import OpenAIChatModel
|
|
from agentscope.formatter import OpenAIChatFormatter
|
|
from agentscope.message import Msg
|
|
|
|
|
|
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """A simple workflow function using the ReAct agent to solve tasks.

    Args:
        task (`Dict`): The task to be solved; must contain a "question" key.
        model (`OpenAIChatModel`): The language model to use.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge. Not used in this workflow, so it must be
            empty or ``None``.

    Returns:
        `WorkflowOutput`: The workflow output containing the agent's response.
    """
    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    # NOTE(review): the original prompt ended with "\\boxed{{}}". Doubled
    # braces are only an escape inside str.format/f-strings; in this plain
    # string the model would literally see "\boxed{{}}". Fixed so the model
    # is instructed to answer within "\boxed{}".
    sys_prompt = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. You can write and "
        "execute Python code to perform calculation or verify your answer. "
        "You should return your final answer within \\boxed{}."
    )
    agent = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        # Allow the agent to manage its own tools via the meta tool.
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )
    response = await agent.reply(
        msg=Msg("user", task["question"], role="user"),
    )
    return WorkflowOutput(
        response=response,
    )
|
|
|
|
|
|
async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """A simple judge function to calculate reward based on agent's response.

    Args:
        task (`Dict`): The task information for the corresponding workflow;
            must contain an "answer" key with the ground truth.
        response (`Msg`): The response generated by the corresponding workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. Not used in this judge, so it must be
            empty or ``None``.

    Returns:
        `JudgeOutput`: The reward value assigned by the judge function.
    """
    # Local import: trinity is a training-time dependency only needed when
    # the judge actually runs.
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    reward_fn = MathBoxedRewardFn()
    # GSM8K raw answers end with "#### <final answer>"; extract the final
    # answer when the marker is present, otherwise stringify as-is.
    truth = task["answer"]
    if isinstance(truth, str) and "####" in truth:
        truth = truth.split("####")[1].strip()
    else:
        truth = str(truth)
    # The agent may produce a message with no text content (e.g. tool-only
    # output), in which case get_text_content() returns None; fall back to
    # an empty string so the reward function never receives None.
    result = response.get_text_content() or ""
    reward_dict = reward_fn(
        response=result,
        truth=truth,
    )
    return JudgeOutput(
        reward=sum(reward_dict.values()),
        metrics=reward_dict,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Data-centric experiment settings (datasets, task selectors, ...) live
    # in YAML; only model and algorithm knobs are configured in code below.
    parser = argparse.ArgumentParser(
        description="Train math-agent with different task selectors",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config_random.yaml",
        help="Path to the configuration YAML file",
    )
    cli_args = parser.parse_args()

    # Policy model served for rollouts during tuning.
    model_cfg = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    # Multi-step GRPO training hyper-parameters.
    algo_cfg = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        eval_interval_steps=20,
        batch_size=16,
    )

    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        config_path=cli_args.config,
        model=model_cfg,
        algorithm=algo_cfg,
    )