# File: evotraders/tuner/math_agent/main.py
# (131 lines, 3.9 KiB, Python)
# -*- coding: utf-8 -*-
"""Example of training a ReAct agent on GSM8K with Trinity-RFT."""
from typing import Dict
from agentscope.tuner import (
tune,
DatasetConfig,
WorkflowOutput,
JudgeOutput,
TunerModelConfig,
AlgorithmConfig,
)
from agentscope.agent import ReActAgent
from agentscope.model import OpenAIChatModel
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """A simple workflow function using the ReAct agent to solve tasks.

    Args:
        task (`Dict`): The task to be solved. Must contain a "question" key
            holding the problem text.
        model (`OpenAIChatModel`): The language model to use.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge. Not used in this workflow, so it must be
            empty or None.

    Returns:
        `WorkflowOutput`: The workflow output containing the agent's response.
    """
    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."
    sys_prompt = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. You can write and "
        "execute Python code to perform calculation or verify your answer. "
        # Single braces: this is a plain literal, not a format template, so
        # the previous "{{}}" would have been shown to the model verbatim
        # and mismatched the \boxed{...} convention the judge extracts.
        "You should return your final answer within \\boxed{}."
    )
    agent = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        # Let the agent manage its own tools (e.g. Python code execution).
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )
    response = await agent.reply(
        msg=Msg("user", task["question"], role="user"),
    )
    return WorkflowOutput(
        response=response,
    )
async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """A simple judge function to calculate reward based on agent's response.

    Args:
        task (`Dict`): The task information for the corresponding workflow.
        response (`Msg`): The response generated by the corresponding workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for LLM-as-a-Judge
            usage. The keys are model names, and the values are the
            corresponding OpenAIChatModel instances.

    Returns:
        `JudgeOutput`: The reward value assigned by the judge function.
    """
    # Imported lazily so the workflow module stays importable without Trinity.
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    assert (
        not auxiliary_models
    ), "No auxiliary models are used in this workflow."
    scorer = MathBoxedRewardFn()

    # GSM8K raw text places the final numeric answer after a "####" marker;
    # anything else is coerced to its string form.
    ground_truth = task["answer"]
    has_marker = isinstance(ground_truth, str) and "####" in ground_truth
    ground_truth = (
        ground_truth.split("####")[1].strip() if has_marker else str(ground_truth)
    )

    # Compare the agent's textual answer against the parsed ground truth.
    scores = scorer(
        response=response.get_text_content(),
        truth=ground_truth,
    )
    return JudgeOutput(
        reward=sum(scores.values()),
        metrics=scores,
    )
if __name__ == "__main__":
    # GSM8K training split pulled from the Hugging Face hub.
    train_data = DatasetConfig(
        path="openai/gsm8k",
        name="main",
        split="train",
    )

    # Small Qwen3 policy model served by 4 single-GPU inference engines.
    policy_cfg = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    # Multi-step GRPO: 8 rollouts per task grouped for relative advantages.
    algo_cfg = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        batch_size=32,
    )

    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        train_dataset=train_data,
        model=policy_cfg,
        algorithm=algo_cfg,
    )