# -*- coding: utf-8 -*-
"""Example of training a FrozenLake agent with Trinity-RFT."""
import os
from typing import Dict

from _frozenlake_agent import FrozenLakeAgent
from _frozenlake_env import FrozenLakeEnv
from agentscope.message import Msg
from agentscope.tuner import (
    tune,
    WorkflowOutput,
    DatasetConfig,
    TunerModelConfig,
    AlgorithmConfig,
)
from agentscope.model import ChatModelBase

async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    """A workflow function using the FrozenLake agent to solve tasks.

    Args:
        task (Dict): The task to be solved, containing environment parameters
            like size, p, seed, is_slippery, etc.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Auxiliary models; this
            workflow does not use any.

    Returns:
        WorkflowOutput: The workflow output containing the reward, response
            and metrics.
    """
    assert len(auxiliary_models) == 0, "No auxiliary models are needed"

    # Extract workflow arguments from task or use defaults
    workflow_args = task.get("workflow_args", {})
    if not workflow_args:
        workflow_args = task

    env_max_steps = workflow_args.get("env_max_steps", 8)
    agent_max_steps = workflow_args.get("agent_max_steps", 10)
    is_slippery = workflow_args.get("is_slippery", False)
    desc = workflow_args.get("desc", None)

    # Extract task-specific arguments (for environment generation)
    size = task.get("size", 8)
    p = task.get("p", 0.8)
    seed = task.get("seed", 42)

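    # NOTE: `_frozenlake_agent.py` and `_frozenlake_env.py` are local helper
    # modules expected to live next to this script. As used in this workflow,
    # the agent exposes `get_prompt`, `reply`, `get_action` and `update_state`,
    # and the environment follows a Gym-style interface with `reset`, `step`
    # (returning observation, reward, done, info) and `success`.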
    # Initialize agent and environment
    agent = FrozenLakeAgent(model=model, max_steps=agent_max_steps)
    env = FrozenLakeEnv(
        max_steps=env_max_steps,
        desc=desc,
        is_slippery=is_slippery,
        size=size,
        p=p,
        seed=seed,
    )

    # Reset environment with task parameters
    observation, _ = env.reset(task)
    observation_str = str(observation)
    rewards = []
    step_count = 0
    done = False
    terminate_reason = None

    # Run agent-environment interaction loop
    for _ in range(agent_max_steps):
        step_count += 1
        try:
            # get prompt
            prompt = agent.get_prompt(observation_str)

            response = await agent.reply(msg=Msg("user", prompt, role="user"))

            # record action and observation
            action = agent.get_action(response)
            agent.update_state(action=action, observation=observation_str)

        except Exception as e:
            terminate_reason = f"agent_error: {str(e)}"
            break

        # environment step
        observation, reward, done, _ = env.step(action)
        observation_str = str(observation)
        rewards.append(reward)

        if done:
            terminate_reason = "success" if env.success() else "hole"
            break

    if terminate_reason is None:
        terminate_reason = "max_steps_reached"

    final_reward = sum(rewards)
    final_observation = observation_str

    # Create response message with environment information
    response_content = (
        f"Final observation:\n{final_observation}\n"
        f"Total reward: {final_reward}\n"
        f"Steps taken: {step_count}\n"
        f"Terminate reason: {terminate_reason}"
    )

    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={
            "env_steps": float(step_count),
            "env_done": float(done),
        },
    )

if __name__ == "__main__":
    dataset = DatasetConfig(
        path="/path/to/frozenlake",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=25600,
        max_tokens=2048,
        inference_engine_num=6,
        reasoning_parser=None,
    )
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=16,
        batch_size=32,
        learning_rate=1e-6,
    )
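    # Note: with a GRPO-style algorithm, each task is typically rolled out
    # `group_size` times and advantages are computed relative to the group,
    # while `batch_size` controls how many tasks go into one update. The exact
    # semantics of "multi_step_grpo" are defined by Trinity-RFT.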
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )  # config.yaml defines additional default parameters
    tune(
        workflow_func=run_frozen_lake,
        model=tuner_model,
        train_dataset=dataset,
        algorithm=algorithm,
        config_path=config_path,
    )