Add example for data augmentation in tuner (#98)
This commit is contained in:
141
tuner/data_augment/main.py
Normal file
141
tuner/data_augment/main.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Example of training a ReAct math-agent with configurable task selector."""
|
||||
from typing import Dict
|
||||
import argparse
|
||||
|
||||
from agentscope.tuner import (
|
||||
tune,
|
||||
WorkflowOutput,
|
||||
JudgeOutput,
|
||||
TunerModelConfig,
|
||||
AlgorithmConfig,
|
||||
)
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.message import Msg
|
||||
|
||||
|
||||
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """A simple workflow function using the ReAct agent to solve tasks.

    Args:
        task (`Dict`): The task to be solved; must contain a "question" key.
        model (`OpenAIChatModel`): The language model to use.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge. Not used in this workflow.

    Returns:
        `WorkflowOutput`: The workflow output containing the agent's response.

    Raises:
        `ValueError`: If auxiliary models are provided; this workflow does
            not use them.
    """
    # NOTE(review): `assert` statements are stripped under `python -O`, so
    # input validation must use an explicit raise instead.
    if auxiliary_models:
        raise ValueError("No auxiliary models are used in this workflow.")

    sys_prompt = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. You can write and "
        "execute Python code to perform calculation or verify your answer. "
        "You should return your final answer within \\boxed{{}}."
    )
    agent = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )
    # Send the task question as a single user message and await the
    # agent's final reply.
    response = await agent.reply(
        msg=Msg("user", task["question"], role="user"),
    )
    return WorkflowOutput(
        response=response,
    )
|
||||
|
||||
|
||||
async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """A simple judge function to calculate reward based on agent's response.

    Args:
        task (`Dict`): The task information for the corresponding workflow;
            must contain an "answer" key with the ground-truth answer.
        response (`Msg`): The response generated by the corresponding workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for LLM-as-a-Judge
            usage. The keys are model names, and the values are the
            corresponding OpenAIChatModel instances. Not used here.

    Returns:
        `JudgeOutput`: The reward value assigned by the judge function.

    Raises:
        `ValueError`: If auxiliary models are provided; this judge does not
            use them.
    """
    # Imported lazily so the example only requires Trinity when judging runs.
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    # NOTE(review): `assert` statements are stripped under `python -O`, so
    # input validation must use an explicit raise instead.
    if auxiliary_models:
        raise ValueError("No auxiliary models are used in this workflow.")

    reward_fn = MathBoxedRewardFn()
    # GSM8K raw answers look like "...reasoning... #### 42"; keep only the
    # part after the "####" marker as the ground truth.
    truth = task["answer"]
    if isinstance(truth, str) and "####" in truth:
        truth = truth.split("####")[1].strip()
    else:
        truth = str(truth)
    # Extract the plain-text answer from the response message.
    result = response.get_text_content()
    reward_dict = reward_fn(
        response=result,
        truth=truth,
    )
    # Total reward is the sum of all component rewards; the per-component
    # values are surfaced as metrics for logging.
    return JudgeOutput(
        reward=sum(reward_dict.values()),
        metrics=reward_dict,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # We recommend using YAML for data-centric experiments.

    arg_parser = argparse.ArgumentParser(
        description="Train math-agent with different task selectors",
    )
    arg_parser.add_argument(
        "--config",
        type=str,
        default="config_random.yaml",
        help="Path to the configuration YAML file",
    )
    cli_args = arg_parser.parse_args()

    # Serving configuration for the model being tuned.
    model_cfg = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    # Training-algorithm configuration.
    algo_cfg = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        eval_interval_steps=20,
        batch_size=16,
    )

    # Launch tuning with the workflow and judge defined above; the task
    # selector comes from the YAML config file.
    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        config_path=cli_args.config,
        model=model_cfg,
        algorithm=algo_cfg,
    )
|
||||
Reference in New Issue
Block a user