---
# Please refer to https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html
# for a detailed explanation of each field.
project: AgentScope
name: GSM8K-Qwen3-0.6B
# Directory to save checkpoints; defaults to ./checkpoints if
# TRINITY_CHECKPOINT_ROOT_DIR is not set.
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
# Training algorithm selection.
algorithm:
  algorithm_type: multi_step_grpo  # a GRPO-based algorithm for multi-step reasoning
  repeat_times: 8  # repeat each training sample 8 times
# Model location and generation limits.
model:
  # Path to the pre-trained model; defaults to Qwen/Qwen3-0.6B if
  # TRINITY_MODEL_PATH is not set.
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B}
  # Maximum number of tokens generated in a response.
  max_response_tokens: 16384
  # Maximum token length for input and output combined.
  # If you hit OOM, try reducing max_model_len and max_response_tokens.
  max_model_len: 24576
  temperature: 1.0
# Hardware layout.
cluster:
  node_num: 1  # cluster with 1 node
  gpu_per_node: 8  # each node has 8 GPUs
# Batching and rollout data source.
# NOTE(review): the explorer_input -> taskset nesting was restored to match the
# Trinity-RFT configuration layout — confirm against the linked config guide.
buffer:
  total_epochs: 1  # run the taskset for 1 epoch
  batch_size: 32  # each step contains 32 samples from the taskset
  # Trainer batch size is 256 (multi-step reasoning generates more training samples).
  train_batch_size: 256
  explorer_input:
    taskset:  # define the taskset for rollout
      name: gsm8k
      path: 'openai/gsm8k'
      subset_name: 'main'
      split: 'train'
# Rollout (exploration) settings.
# NOTE(review): the keys under rollout_model were re-nested to match the
# Trinity-RFT configuration layout — confirm against the linked config guide.
explorer:
  runner_per_model: 16  # each model has 16 runners for parallel rollout
  max_timeout: 600  # max timeout for each rollout is 600 seconds
  rollout_model:
    engine_num: 4  # set up 4 vLLM inference model instances
    tensor_parallel_size: 1  # each model instance uses a tensor parallel size of 1
    # The following parameters provide the OpenAI-style API; don't change them.
    enable_openai_api: true
    enable_history: true
    enable_auto_tool_choice: true
    # tool_call_parser and reasoning_parser for the Qwen3 series; if you use
    # other models, adjust accordingly.
    tool_call_parser: hermes
    reasoning_parser: deepseek_r1
# Weight synchronization settings.
synchronizer:
  sync_style: dynamic_by_explorer
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1800  # wait for 30 minutes
# Trainer settings.
trainer:
  save_interval: 100  # save a checkpoint every 100 steps
  use_dynamic_bsz: true
  ulysses_sequence_parallel_size: 1  # use sequence parallelism to reduce memory usage
# Experiment monitoring backend.
monitor:
  # tensorboard is used here; wandb, mlflow or swanlab can also be used.
  monitor_type: tensorboard