Add Math Agent (Quick Start for AgentScope Tuner) (#102)
This commit is contained in:
130
tuner/math_agent/main.py
Normal file
130
tuner/math_agent/main.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Example of training a ReAct agent on GSM8K with Trinity-RFT."""
|
||||
from typing import Dict
|
||||
|
||||
from agentscope.tuner import (
|
||||
tune,
|
||||
DatasetConfig,
|
||||
WorkflowOutput,
|
||||
JudgeOutput,
|
||||
TunerModelConfig,
|
||||
AlgorithmConfig,
|
||||
)
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.model import OpenAIChatModel
|
||||
from agentscope.formatter import OpenAIChatFormatter
|
||||
from agentscope.message import Msg
|
||||
|
||||
|
||||
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """Solve a single task with a tool-using ReAct agent.

    Args:
        task (`Dict`): The task record; the problem text is read from
            ``task["question"]``.
        model (`OpenAIChatModel`): The language model driving the agent.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            Extra chat models for LLM-as-a-Judge usage. This workflow does
            not use any, and asserts that none are supplied.

    Returns:
        `WorkflowOutput`: Wraps the agent's final reply message.
    """
    # This workflow is judged by a rule-based reward, so no auxiliary
    # (judge) models should be configured.
    assert not auxiliary_models, "No auxiliary models are used in this workflow."

    # NOTE(review): the double braces in \boxed{{}} render literally as
    # "{{}}" in a plain str — presumably escaped for a downstream
    # template/.format pass; confirm against the agent's prompt handling.
    instructions = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. You can write and "
        "execute Python code to perform calculation or verify your answer. "
        "You should return your final answer within \\boxed{{}}."
    )

    solver = ReActAgent(
        name="react_agent",
        sys_prompt=instructions,
        model=model,
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )

    question = Msg("user", task["question"], role="user")
    reply = await solver.reply(msg=question)

    return WorkflowOutput(response=reply)
|
||||
|
||||
|
||||
async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """Score an agent response against the GSM8K ground truth.

    Args:
        task (`Dict`): The task record; ground truth is read from
            ``task["answer"]``.
        response (`Msg`): The message produced by the workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            Extra chat models for LLM-as-a-Judge usage. This judge is
            rule-based and asserts that none are supplied.

    Returns:
        `JudgeOutput`: Total reward plus the per-component reward metrics.
    """
    # Imported lazily so this module loads without Trinity-RFT installed.
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    assert not auxiliary_models, "No auxiliary models are used in this workflow."

    # GSM8K raw answers end with "#### <final answer>"; extract the part
    # after the marker, otherwise fall back to the stringified value.
    gold = task["answer"]
    if isinstance(gold, str) and "####" in gold:
        gold = gold.split("####")[1].strip()
    else:
        gold = str(gold)

    # Compare the agent's textual reply against the gold answer using the
    # boxed-answer reward function.
    scorer = MathBoxedRewardFn()
    scores = scorer(
        response=response.get_text_content(),
        truth=gold,
    )

    return JudgeOutput(
        reward=sum(scores.values()),
        metrics=scores,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Training data: the "main" config of OpenAI's GSM8K, train split.
    train_data = DatasetConfig(
        path="openai/gsm8k",
        name="main",
        split="train",
    )

    # Policy model under training and its rollout/inference settings.
    policy_cfg = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    # Multi-step GRPO: 8 rollouts per task grouped for advantage estimation.
    algo_cfg = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        batch_size=32,
    )

    # Launch tuning with the workflow/judge pair defined above.
    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        train_dataset=train_data,
        model=policy_cfg,
        algorithm=algo_cfg,
    )
|
||||
Reference in New Issue
Block a user