Add Tuner learn_to_ask example (#101)

This commit is contained in:
chenyushuo
2026-01-16 19:24:46 +08:00
committed by GitHub
parent 5855c5161b
commit 3821fb04ac
10 changed files with 1643 additions and 0 deletions


@@ -0,0 +1,406 @@
# Training Learn2Ask with AgentScope-Tuner
This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
---
## Task Setting
In this example, given a user's chief complaint, the medical assistant agent proactively asks targeted questions to gather sufficient symptom information for a comprehensive assessment of the user's health condition. The querying process should be efficient: the agent must optimize question quality, and terminate the interview as soon as the collected information is adequate for subsequent clinical assessment or decision-making.
Here we use the `ReActAgent` for this task; no tools are required.
---
## Hardware Requirements
- **Training with GPUs**: At least **8 H20 GPUs** (or equivalent).
- **Training without GPUs**: You can use the **[Tinker](https://thinkingmachines.ai/tinker/)** backend without any GPUs.
> 💡 All code and configuration files are located in:
> `tuner/learn_to_ask/`
Key files:
- Workflow & Training: `tuner/learn_to_ask/main.py`
- Prompts: `tuner/learn_to_ask/prompt.py`
- Training config: `tuner/learn_to_ask/config.yaml`
- Data preparation scripts: `tuner/learn_to_ask/data_prepare/`
---
## Dataset Preparation
> [!NOTE]
> In this example, we use an open-source dataset directly for training. In practice, however, you would typically start by collecting interaction logs between your deployed agent and users. After filtering these raw logs to curate a high-quality dataset, you can follow the same pipeline to enhance your agent's proactive capabilities using AgentScope-Tuner. Happy tuning!
### 1.1 Download the Dataset
Download the **[RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv)** dataset (in `.jsonl` format).
You can use the following Python script to download the dataset:
```python
from huggingface_hub import snapshot_download
# Download to local directory, e.g., `./tuner/learn_to_ask/data`
local_dir = "./tuner/learn_to_ask/data"
snapshot_download(
repo_id="datajuicer/RealMedConv",
repo_type="dataset",
local_dir=local_dir,
)
```
Each line in `train_origin.jsonl` (or `test_origin.jsonl`) represents a complete doctor-patient conversation log, like this:
```json
{
"session_id": 35310,
"diagn": "Upper Respiratory Tract Infection",
"messages": [
{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"},
{"role": "user", "content": "I took Amoxicillin"},
...
{"role": "assistant", "content": "<med_search>"}
]
}
```
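To make the record structure concrete, here is a minimal sketch that parses one line of the JSONL file. The inline string is an abbreviated version of the example record above, not the full log:

```python
import json

# One (abbreviated) line from train_origin.jsonl
line = (
    '{"session_id": 35310, "diagn": "Upper Respiratory Tract Infection", '
    '"messages": [{"role": "user", "content": "Sore throat, phlegm"}]}'
)

session = json.loads(line)
print(session["session_id"], session["diagn"], len(session["messages"]))
```

Each line is an independent JSON object, so the full file can be read with one `json.loads` per line.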
### 1.2 Preprocess the Data
You need to convert raw conversation logs into training samples. This involves two steps:
#### 🔹 Step A: Segment Conversations & Extract Labels
Split each conversation into **context-future pairs**, and extract ground-truth symptom information (`info_truth`) from what happens next.
```bash
python tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py \
    --input_file /path/to/RealMedConv/train_origin.jsonl \
--output_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
--model_path Qwen/Qwen2.5-32B-Instruct
```
#### 🔹 Step B: Build Final Training Dataset
Convert the processed samples into the final format used for training/testing.
```bash
python tuner/learn_to_ask/data_prepare/2_build_dataset.py \
--input_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
--output_file tuner/learn_to_ask/data/train.jsonl
```
---
### How It Works: Context-Future Segmentation
For every turn in a conversation, we create a sample with:
- `messages`: The **observed dialogue history** up to that point (the *context*).
- `remaining_chat`: Everything that happens **after** that point (the *future*).
- A unique ID: `cid = {session_id}_{turn_index}`
Example output:
```json
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [ ... up to turn 7 ... ],
"remaining_chat": [ ... all future messages ... ]
}
```
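The segmentation above can be sketched as follows. `segment_session` is a hypothetical name for illustration; the actual implementation lives in `message_splitter.py` and may choose split points differently (e.g., only at assistant turns):

```python
from typing import Dict, List


def segment_session(session: Dict) -> List[Dict]:
    """Split one conversation into context-future pairs, one per split point."""
    messages = session["messages"]
    samples = []
    for turn_index in range(1, len(messages)):
        samples.append({
            "cid": f"{session['session_id']}_{turn_index}",
            "session_id": str(session["session_id"]),
            "diagn": session["diagn"],
            "messages": messages[:turn_index],        # observed context
            "remaining_chat": messages[turn_index:],  # unobserved future
        })
    return samples


session = {
    "session_id": 35310,
    "diagn": "Upper Respiratory Tract Infection",
    "messages": [
        {"role": "user", "content": "Sore throat, cough"},
        {"role": "assistant", "content": "How long have you had it?"},
        {"role": "user", "content": "Two days"},
    ],
}
samples = segment_session(session)
print(len(samples))  # one sample per split point
```

Each sample sees only the prefix of the dialogue; everything after the split point becomes supervision signal.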
### Extract Ground-Truth Labels
From `remaining_chat`, we automatically derive two key labels:
- `decision_truth`: Should the assistant **continue asking questions** (`"continue"`) or **stop** (`"stop"`)?
- `info_truth`: Structured list of symptoms mentioned later (used to compute reward signals during training).
Example:
```json
{
"decision_truth": "continue",
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, ..."
}
```
These labels power the reward functions $R_a$ (action accuracy) and $R_s$ (symptom coverage) during training.
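The decision label follows a simple rule, mirroring the logic in `2_build_dataset.py`: if no further user messages appear in the future, the assistant should stop. A simplified sketch (the actual script operates on a string-serialized `remaining_chat` rather than a message list):

```python
from typing import Dict, List


def derive_decision(remaining_chat: List[Dict]) -> str:
    """'continue' if the patient still says something later, else 'stop'."""
    has_future_user_msg = any(m["role"] == "user" for m in remaining_chat)
    return "continue" if has_future_user_msg else "stop"


# The patient speaks again later -> keep asking questions.
assert derive_decision([{"role": "user", "content": "It hurts at night"}]) == "continue"
# Only assistant actions remain -> the interview should end.
assert derive_decision([{"role": "assistant", "content": "<med_search>"}]) == "stop"
```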
---
## Code Implementation
### Agent Workflow
The workflow function `run_react_agent` defines how the `ReActAgent` produces a response for each training sample.
```python
async def run_react_agent(
task: Dict,
model: OpenAIChatModel,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> WorkflowOutput:
assert (
len(auxiliary_models) == 1
), "Please provide only one `auxiliary_models` for `learn_to_ask`."
    import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "prompt.py"),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if TRAIN_MODE == "Ra":
sys_prompt = module.rollout_prompt_med_Ra
else:
sys_prompt = module.rollout_prompt_med
agent = ReActAgent(
name="react_agent",
sys_prompt=sys_prompt,
model=model,
formatter=OpenAIChatFormatter(),
toolkit=None,
memory=InMemoryMemory(),
max_iters=1,
)
messages = format_messages(task["messages"])
response = await agent.reply(
[
Msg(name=x["role"], content=x["content"], role=x["role"])
for x in messages
],
)
return WorkflowOutput(
response=response,
)
```
### Judge Function
The judge function `learn2ask_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def learn2ask_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> JudgeOutput:
assert (
len(auxiliary_models) == 1
), "Please provide only one `auxiliary_models` for `learn_to_ask`."
response_text = response.get_text_content()
action_truth = (
task["decision_truth"] if "decision_truth" in task else "continue"
)
action_response = "stop" if "<stop />" in response_text else "continue"
if action_truth == action_response:
action_score = 1.0
if action_truth == "continue":
score_dict = await llm_reward(
task=task,
response=response_text,
auxiliary_models=auxiliary_models, # LLM-as-a-Judge
)
if score_dict != {}:
format_score = float(score_dict.get("format_score", 0.0))
content_score = float(score_dict.get("content_score", 0.0))
else:
format_score, content_score = 0.0, 0.0
else:
content_score = 1.0
format_score = 1.0 if response_text == "<stop />" else 0.0
else:
action_score, format_score, content_score = 0.0, 0.0, 0.0
if TRAIN_MODE == "Ra+Rs": # the default setting
final_reward = (
action_score * (1 + 2 * content_score) + format_score
if FUSION_MODE != "sum"
else action_score + content_score + format_score
)
elif TRAIN_MODE == "Ra": # for Ra only (without Rs)
final_reward = 2 * content_score + format_score
else: # for Rs only (without Ra)
final_reward = action_score * 3 + format_score
return JudgeOutput(
reward=final_reward,
metrics={"reward": final_reward},
)
```
This reward function considers:
- Action accuracy: `action_score`
- Question quality (Symptom coverage): `content_score`
- Format score: `format_score`
See [main.py](./main.py) for implementation details.
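As a worked example of how the three scores combine, the following helper reproduces the fusion arithmetic from `learn2ask_judge` above (`fuse_reward` is a hypothetical name for illustration):

```python
def fuse_reward(action: float, content: float, fmt: float,
                train_mode: str = "Ra+Rs", fusion_mode: str = "default") -> float:
    """Combine the three scores the same way learn2ask_judge does."""
    if train_mode == "Ra+Rs":
        if fusion_mode == "sum":
            return action + content + fmt
        return action * (1 + 2 * content) + fmt  # default multiplicative fusion
    if train_mode == "Ra":
        return 2 * content + fmt
    return action * 3 + fmt  # Rs-only mode


# Correct "continue" decision, half-relevant question, well-formatted:
# 1.0 * (1 + 2 * 0.5) + 1.0
print(fuse_reward(1.0, 0.5, 1.0))  # → 3.0
```

Note that the multiplicative default gates the content reward on a correct decision: a wrong `action_score` zeroes out the symptom-coverage term entirely.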
---
## Configure and Train the Model
### Option A: Edit Python Script (Simple)
Open `tuner/learn_to_ask/main.py` and adjust settings:
```python
if __name__ == "__main__":
train_mode = "Ra+Rs" # Use both action and symptom rewards
fusion_mode = "default" # How to combine rewards
dataset = DatasetConfig(path="tuner/learn_to_ask/data", split="train")
tuner_model = OpenAIChatModel(
model_path="Qwen/Qwen2.5-7B-Instruct",
max_model_len=8192,
tensor_parallel_size=1, # Adjust based on your GPU setup
...
)
auxiliary_models = {
AUXILIARY_MODEL_NAME: OpenAIChatModel(
model_path="Qwen/Qwen2.5-32B-Instruct", # Larger model for evaluation
tensor_parallel_size=2,
...
)
}
algorithm = AlgorithmConfig(
algorithm_type="grpo",
learning_rate=5e-7,
batch_size=64,
)
tune(...) # Starts training
```
### Option B: Use YAML Config (Advanced)
Edit `tuner/learn_to_ask/config.yaml` for more control.
#### 🌐 No GPU? Use Tinker!
If you don't have GPUs, enable the **Tinker backend** by setting:
```yaml
model:
tinker:
enable: true # ← Set this to true
```
Also, make sure to update the `model_path` in `tuner/learn_to_ask/main.py` to point to a model that's compatible with Tinker.
> 🔗 Learn more about Tinker Backend: [Tinker Backend Documentation](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/example_tinker_backend.html)
### Launch Training
```bash
python tuner/learn_to_ask/main.py
```
---
## Evaluation
Use the **rollout-and-evaluate pipeline**:
1. Generate responses on the test set.
2. Score them using a powerful evaluator model (`Qwen2.5-32B-Instruct`).
Run evaluation:
```bash
python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
--eval_model_path path/to/your/trained/model \
--grader_model_path Qwen/Qwen2.5-32B-Instruct \
--test_file_path tuner/learn_to_ask/data/test.jsonl \
--rollout_file_path path/to/rollout.jsonl \
--eval_file_path path/to/output.jsonl
```
> ⚠️ **Note**: Your trained model must be converted to **Hugging Face format** first.
> See: [Converting FSDP Checkpoints Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html)
---
## Experimental Results
We compared three approaches:
- **Base model**: `Qwen2.5-7B-Instruct` (no fine-tuning)
- **Trinity**: Direct response generation
- **As-tune (Learn2Ask)**: Uses a ReAct agent for proactive questioning
| Metric | Base Model | Trinity | As-tune (Learn2Ask) |
|--------------------------------------|-----------:|--------:|--------------------:|
| Avg. continue content | 0.436 | 0.496 | 0.509 |
| Win rate (continue content) | 0.122 | 0.246 | 0.224 |
| Avg. continue decision accuracy | 0.963 | 0.909 | 0.922 |
| Avg. stop decision accuracy | 0.024 | 0.927 | 0.909 |
| **Total decision accuracy** | 0.467 | 0.918 | 0.915 |
| **Response format quality** | 0.376 | 0.713 | 0.882 |
| **Total reward** | 1.281 | 3.078 | 3.237 |
![Training Curves](./learn2ask.png)
> ✅ **Key insight**: Learn2Ask (As-tune) achieves the highest overall performance by teaching the model **when and what to ask**—making it truly proactive.
### Concrete Example
An example of the agent's behavior on evaluation data is shown below:
```json
{
"cid": "206_1",
"session_id": 206,
"diagn": "Common cold",
"messages": [
{
"role": "assistant",
"content": "Hello, currently there are many people consulting, please wait patiently."
},
{
"role": "user",
"content": "Hello, this morning around 10 o'clock, I suddenly felt a bit feverish and had a heavy head. I took my temperature and it was 38.5°C. I took one tablet of paracetamol at 1 o'clock, which helped a lot. By 6 or 7 in the evening, I felt cold and feverish again, so I took another tablet of paracetamol. I'm not sure if it's a side effect of the medication, but I feel a bit nauseous and my stomach keeps producing gas. Now, I feel cold again, have a stuffy nose with clear runny nose, an itchy throat with coughing, a heavy head, and can't sleep."
}
],
"decision_truth": "continue",
"info_truth": "symptom: needle-like sensation, symptom location: ear, symptom nature: sharp pain, symptom severity: occurred twice",
"rollouts": [
"Do you also experience body aches along with your symptoms? \nA) Yes B) No C) Not Sure",
"Do you experience these symptoms every time after taking paracetamol? (A) Yes, definitely. (B) Sometimes. (C) No, never before.",
"Do you have a history of allergies? A) Yes B) No C) Not Sure"
],
"grades": [
{
"think": "The doctor's last message asks about body aches using multiple choice options rather than a single direct open-ended question, which deviates from the typical questioning style indicated by the grading policy. However, since it is framed in the form of \"Do you...\", it counts as one question despite being presented with choices. Thus, it satisfies the criteria for 'contains exactly one question'. For the content score, the query does not pertain directly to the details listed in the reference information such as the specific symptom ('needle-like' sharp pain in the ear that has occurred twice), thereby making it highly relevant yet indirectly related due to its focus on comprehensive assessment of flu-related discomforts without addressing specifics noted in the patients primary concern documented earlier.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
},
{
"think": "The doctor's last message includes just one multiple-choice question regarding whether the patient experiences those mentioned symptoms each time they take paracetamol. This does relate highly to understanding possible drug-related symptoms; however, none of them aligns perfectly with \"needle-like\" sensations occurring specifically in ears according to the reference information given.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
},
{
"think": "The doctors last statement does contain just one question pertaining to allergy history, which is highly relevant when trying to diagnose symptoms such as those described by the patient (fever, nausea). However, none of these concerns specifically relate back to the reference information detailing \"needle-like sensation\", \"sharp pain\" related to the ears occurring twice. Therefore, while highly pertinent medically, they do not pertain to the exact points outlined in the Ref Info section about the patient experience according to that specific prompt context.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
}
]
}
```
---
## 📚 Citation
If you use this code or framework, please cite our work:
```bibtex
@misc{learn2ask,
title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs},
author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding},
year={2025},
eprint={2510.25441},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2510.25441}
}
```


@@ -0,0 +1,30 @@
project: AgentScope-ReAct
name: Learn_to_Ask-Qwen2.5-7B-fixed
# directory to save checkpoints, default to ./checkpoints if TRINITY_CHECKPOINT_ROOT_DIR not set
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo # a GRPO-based algorithm for multi-step reasoning
model:
# path to the pre-trained model, default to Qwen/Qwen2.5-7B-Instruct if TRINITY_MODEL_PATH not set
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
tinker: # tinker config, you can set tinker parameters here
enable: false # if true, tinker will be enabled
cluster:
node_num: 1 # cluster with 1 node
gpu_per_node: 8 # each node has 8 GPUs
buffer:
  total_epochs: 4 # run taskset for 4 epochs
explorer:
runner_per_model: 32 # each model has 32 runners for parallel rollout
max_timeout: 600 # max timeout for each rollout is 600 seconds
synchronizer:
sync_style: fixed
sync_method: 'nccl'
sync_interval: 10
sync_timeout: 7200 # wait for 120 minutes
trainer:
save_interval: 90 # save checkpoint every 90 steps
use_dynamic_bsz: true
ulysses_sequence_parallel_size: 1 # use sequence parallelism to reduce memory usage
monitor:
monitor_type: wandb # here we use wandb; you can also use tensorboard, mlflow or swanlab


@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
import time
from typing import Union
from llm_info_extraction import llm_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines
def process_jsonl_file(
input_file: str,
output_file: str,
model_call_mode: str = "online_api",
max_retries: int = 3,
**kwargs: dict,
) -> str:
"""
Process all sessions in a JSONL file and save results to output file.
Args:
input_file (str): Path to input JSONL file
output_file (str): Path to output JSONL file
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls
Returns:
str: Success message or error information
"""
try:
# Read and process each session
with open(input_file, "r", encoding="utf-8") as infile, open(
output_file,
"w",
encoding="utf-8",
) as outfile:
for line_num, line in enumerate(infile, 1):
if line.strip():
try:
session = json.loads(line)
print(
f"Processing session "
f"{session.get('session_id', 'unknown')} "
f"(line {line_num})...",
)
# Process the session
processed_lines = process_session(
session,
model_call_mode,
max_retries,
**kwargs,
)
for processed_line in processed_lines:
outfile.write(processed_line + "\n")
except json.JSONDecodeError as e:
print(
f"Warning: Skipping invalid JSON at line "
f"{line_num}: {e}",
)
except Exception as e:
print(
f"Warning: Error processing session at line "
f"{line_num}: {e}",
)
return f"Successfully processed. Results saved to {output_file}"
except Exception as e:
return f"Error processing JSONL file: {str(e)}"
def process_session(
session: dict,
model_call_mode: str = "online_api",
max_retries: int = 3,
**kwargs: dict,
) -> Union[list, str]:
"""
Pipeline function that splits messages into rounds and extracts info from
each round's remaining chat.
Args:
session (dict): Session dictionary containing 'session_id', 'diagn',
and 'messages' keys
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls
Returns:
list: List of JSON strings with added "info_set" key,
or error information
"""
try:
# Step 1: Split messages into JSON lines
json_lines = split_session_to_json_lines(session)
# Step 2: Process each JSON line with LLM info extraction
processed_lines = []
for line in json_lines:
data = json.loads(line)
remaining_chat = data.get("remaining_chat", "")
# Retry loop for LLM calls
info_set = None
for attempt in range(max_retries):
try:
                    # Call LLM info extraction
llm_response = llm_info_extraction(
remaining_chat,
model_call_mode,
**kwargs,
)
info_set = parse_llm_output(llm_response)
if isinstance(info_set, list):
break
else:
# If parsing failed, this is an error message
print(f"Attempt {attempt + 1} failed: {info_set}")
if attempt < max_retries - 1:
time.sleep(1)
except Exception as e:
print(
f"Attempt {attempt + 1} failed with exception: "
f"{str(e)}",
)
if attempt < max_retries - 1:
time.sleep(1) # Shorter wait for testing
data["info_set"] = info_set
processed_lines.append(json.dumps(data, ensure_ascii=False))
return processed_lines
except Exception as e:
return f"Pipeline error: {str(e)}"
# Example usage:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_origin.jsonl",
)
parser.add_argument(
"--output_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
)
parser.add_argument(
"--model_call_mode",
type=str,
choices=["online_api", "local_vllm"],
default="local_vllm",
)
parser.add_argument("--model_path", type=str, required=True)
args = parser.parse_args()
print(
process_jsonl_file(
input_file=args.input_file,
output_file=args.output_file,
model_call_mode=args.model_call_mode,
model_path=args.model_path,
# Additional parameters for API calls
),
)


@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
def process_message(json_obj: dict) -> tuple:
info_set = json_obj.get("info_set")
info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
if "user: " not in json_obj["remaining_chat"]:
decision_str = "stop"
else:
decision_str = "continue"
if not info_set_str and decision_str == "continue":
if_keep = False
else:
if_keep = True
return if_keep, info_set_str, decision_str
def main(input_file_path: str, output_file_path: str) -> None:
with open(input_file_path, "r", encoding="utf-8") as infile, open(
output_file_path,
"w",
encoding="utf-8",
) as outfile:
print("data processing started...")
for line in infile:
data = json.loads(line.strip())
if_keep, info_set, decision = process_message(data)
if not if_keep:
continue
new_item = {
"cid": data["cid"],
"session_id": data["session_id"],
"diagn": data["diagn"],
"messages": data["messages"],
"decision_truth": decision,
"info_truth": info_set,
}
outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
print("job done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# The file generated by 1_info_extract_pipeline.py
parser.add_argument(
"--input_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
)
# The final file for training or testing
parser.add_argument(
"--output_file",
type=str,
default="tuner/learn_to_ask/data/train.jsonl",
)
args = parser.parse_args()
main(args.input_file, args.output_file)


@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
"""
This script is used to use VLLM to generate rollout samples from the converted
checkpoints.
"""
import argparse
import copy
import gc
import json
import math
import os
import re
import time
from typing import Any, List
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
def init_llm(model_path: str) -> tuple:
tokenizer = AutoTokenizer.from_pretrained(model_path)
device_count = torch.cuda.device_count()
print(f"device_count={device_count}")
if device_count < 1:
raise RuntimeError("No GPU available for multi-card inference.")
print(f"Loading model from: {model_path}")
llm = LLM(model=model_path, tensor_parallel_size=device_count)
print("Model loaded successfully!")
sampling_params = SamplingParams(
temperature=1.0,
top_p=0.95,
max_tokens=512,
repetition_penalty=1.2,
)
return llm, tokenizer, sampling_params
def rollout(
llm: Any,
tokenizer: Any,
sampling_params: Any,
input_file_path: str,
output_file_path: str,
rollout_repeat: int = 3,
) -> None:
    import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
rollout_prompt = module.rollout_prompt_med
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
print(f"loaded samples: {len(sample_list)}")
for index, sample in enumerate(sample_list):
record = copy.deepcopy(sample)
print(f"index: {index}, session_id: {sample['session_id']}")
messages = [{"role": "system", "content": rollout_prompt}] + sample[
"messages"
]
# Some tokenizers (e.g., Qwen) support the `enable_thinking` argument,
# but others do not. Try with the argument first, and fall back if
# it is not accepted.
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
response_list = []
for i in range(rollout_repeat):
time_probe = time.perf_counter()
outputs = llm.generate([prompt], sampling_params=sampling_params)
print(f"time cost: {time.perf_counter() - time_probe}")
for output in outputs:
response = output.outputs[0].text
response_list.append(response)
print(f"rollout #{i}: {response}\n")
record["rollouts"] = response_list
# append to output file
with open(output_file_path, "a") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
def eval_sample(
llm: Any,
tokenizer: Any,
sampling_params: Any,
input_file_path: str,
output_file_path: str,
) -> None:
    import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
grader_prompt = module.reward_prompt_med
print(f"input_file_path: {input_file_path}")
print(f"output_file_path: {output_file_path}")
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
print(f"Total records: {len(sample_list)}")
def res_formatter(res_content: str) -> dict:
pattern = r"<(\w+)>(.*?)</\1>"
matches = re.findall(pattern, res_content)
result = {}
for tag_name, content in matches:
result[tag_name] = content
return result
def msg2str(msg_list: List) -> str:
result_str = ""
for msg in msg_list:
if msg["role"] == "user":
result_str += f"patient: {msg['content']}\n"
if msg["role"] == "assistant":
result_str += f"doctor: {msg['content']}\n"
return result_str
for index, sample in enumerate(sample_list):
print(f"index: {index}, cid: {sample['cid']}")
action_truth = sample["decision_truth"]
info_truth = sample["info_truth"] if sample["info_truth"] else "None."
print(f"action_truth: {action_truth}, info_truth:{info_truth}")
sys_prompt = grader_prompt.format(info_truth)
history = msg2str(sample["messages"])
sample["grades"] = []
for rollout in sample["rollouts"]:
time_probe = time.perf_counter()
action_score, content_score, format_score, res_think = (
0,
0,
0,
"NA",
)
if "<stop />" in rollout:
action_rollout = "stop"
else:
action_rollout = "continue"
if action_truth == action_rollout:
action_score = 1
if action_truth == "continue":
user_content = history + f"doctor: {rollout}"
messages = [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": user_content},
]
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
# Fallback for tokenizers that do not support the
# Qwen-specific `enable_thinking` argument.
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
outputs = llm.generate(
[prompt],
sampling_params=sampling_params,
)
for output in outputs:
response = output.outputs[0].text
print(f"Response: {response}\n")
res_dict = res_formatter(response)
try:
format_score = float(res_dict.get("format_score", 0.0))
content_score = float(
res_dict.get("content_score", 0.0),
)
res_think = res_dict.get("think", "None")
except Exception as e:
print(e)
else:
content_score = 1.0
format_score = 1.0 if rollout == "<stop />" else 0.0
else:
action_score, format_score, content_score = 0, 0, 0
grade_result = {
"think": res_think,
"action_score": action_score,
"format_score": format_score,
"content_score": content_score,
}
sample["grades"].append(grade_result)
json_str = json.dumps(grade_result, ensure_ascii=False, indent=2)
print(
f"grade_result:{json_str}",
)
print(f"time_cost:{time.perf_counter() - time_probe}")
# append sample to output file
with open(output_file_path, "a") as f:
f.write(json.dumps(sample, ensure_ascii=False) + "\n")
print("\n======================\n")
def compute_score(input_file_path: str) -> None:
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
continue_count, continue_content_score, continue_content_full = 0, 0, 0
continue_decision_score = 0
stop_count, stop_decision_score = 0, 0
total_reward, total_format = 0, 0
(
continue_count_correct,
continue_content_score_correct,
continue_content_full_correct,
) = (0, 0, 0)
for sample in sample_list:
for rollout, grade in zip(sample["rollouts"], sample["grades"]):
if math.isnan(grade["content_score"]) or math.isnan(
grade["format_score"],
):
continue
if sample["decision_truth"] == "continue":
continue_count += 1
continue_content_score += grade["content_score"]
continue_content_full += (
1 if grade["content_score"] == 1 else 0
)
continue_decision_score += grade["action_score"]
if "<stop />" not in rollout:
continue_count_correct += 1
continue_content_score_correct += grade["content_score"]
continue_content_full_correct += (
1 if grade["content_score"] == 1 else 0
)
else:
stop_count += 1
stop_decision_score += grade["action_score"]
total_reward += (
grade["action_score"] * (1 + 2 * grade["content_score"])
+ grade["format_score"]
)
total_format += grade["format_score"]
result = {
"ave_continue_content": continue_content_score
/ max(1, continue_count),
"win_continue_content": continue_content_full / max(1, continue_count),
"ave_continue_content if correct": continue_content_score_correct
/ max(1, continue_count_correct),
"win_continue_content if correct": continue_content_full_correct
/ max(1, continue_count_correct),
"ave_continue_decision": continue_decision_score
/ max(1, continue_count),
"ave_stop_decision": stop_decision_score / max(1, stop_count),
"ave_total_decision": (continue_decision_score + stop_decision_score)
/ max(1, continue_count + stop_count),
"ave_total_format": total_format / max(1, continue_count + stop_count),
"ave_total_reward": total_reward / max(1, continue_count + stop_count),
}
print(f"total count: {continue_count + stop_count}")
print(json.dumps(result, ensure_ascii=False, indent=4))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--rollout_repeat", type=int, default=3)
# Ckpt for testing
parser.add_argument("--eval_model_path", type=str, required=True)
    # Model used for grading; Qwen2.5-32B-Instruct is recommended
parser.add_argument("--grader_model_path", type=str, required=True)
# Your test sample path [input]
parser.add_argument("--test_file_path", type=str, required=True)
# Rollout results given test samples [output]
parser.add_argument("--rollout_file_path", type=str, required=True)
# Final output given rollout results [output]
parser.add_argument("--eval_file_path", type=str, required=True)
args = parser.parse_args()
# rollout stage
llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
rollout(
llm,
tokenizer,
sampling_params,
args.test_file_path,
args.rollout_file_path,
args.rollout_repeat,
)
del llm # clean up the memory after the inference
gc.collect()
torch.cuda.empty_cache() # release gpu memory
# eval stage
llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
eval_sample(
llm2,
tokenizer2,
sampling_params2,
args.rollout_file_path,
args.eval_file_path,
)
compute_score(args.eval_file_path)


@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import os
from typing import List, Union
import openai
import torch
import transformers
tokenizer = None
llm = None
def llm_info_extraction(
remaining_chat: str,
model_call_mode: str,
**kwargs: dict,
) -> str:
"""
Extract information from remaining_chat using LLM.
Args:
remaining_chat (str): The chat content to process
model_call_mode (str): Either "online_api" or "local_vllm"
**kwargs: Additional parameters for API calls
Returns:
str: Response text from LLM or error information
"""
# Create messages format with system and user roles
system_message = """
# Task:
You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field.
# Requirements:
- Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output.
- Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown").
- Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field.
- Format each entry as a string enclosed in single quotes ('), separate multiple entries with commas, and escape characters inside the strings as needed. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
- Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology.
# Example output format:
['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']
""" # noqa: E501
user_message = remaining_chat
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": "```\n" + user_message + "\n```\n"},
]
try:
if model_call_mode == "online_api":
# OpenAI-style API call
return _call_online_api(messages, **kwargs)
elif model_call_mode == "local_vllm":
# Local vLLM call
return _call_local_vllm(messages, **kwargs)
else:
return (
f"Error: Invalid model_call_mode '{model_call_mode}'. "
"Must be 'online_api' or 'local_vllm'."
)
except Exception as e:
return f"Error occurred: {str(e)}"
def _call_online_api(messages: List, **kwargs: dict) -> str:
"""Handle OpenAI-style API calls"""
# Extract API parameters from kwargs or use defaults
api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY"))
api_base = kwargs.get(
"api_base",
"https://dashscope.aliyuncs.com/compatible-mode/v1",
)
model = kwargs.get("model", "qwen2.5-72b-instruct")
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 500)
client = openai.OpenAI(api_key=api_key, base_url=api_base)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return response.choices[0].message.content
def _call_local_vllm(messages: List, **kwargs: dict) -> str:
"""Handle local vLLM calls"""
try:
from vllm import LLM, SamplingParams
model_path = kwargs.get("model_path")
if not model_path:
return "Error: model_path is required for local vLLM inference"
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 512)
top_p = kwargs.get("top_p", 0.9)
repetition_penalty = kwargs.get("repetition_penalty", 1.1)
# GPU/CUDA related parameters for vLLM
tensor_parallel_size = kwargs.get(
"tensor_parallel_size",
torch.cuda.device_count(),
)
gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9)
enforce_eager = kwargs.get("enforce_eager", False)
dtype = kwargs.get("dtype", "auto")
max_model_len = kwargs.get("max_model_len", 4096)
# Initialize the LLM with the provided model path and GPU parameters
global llm, tokenizer
if llm is None:
llm = LLM(
model=model_path,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
dtype=dtype,
max_model_len=max_model_len,
)
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
)
# Convert messages to a single prompt string
if tokenizer is None:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
outputs = llm.generate([prompt], sampling_params)
return outputs[0].outputs[0].text
except ImportError:
return (
"Error: vLLM library not installed. "
"Please install it with 'pip install vllm'"
)
except Exception as e:
return f"Error in local vLLM inference: {str(e)}"
def parse_llm_output(output_str: str) -> Union[List[str], str]:
"""
Convert the LLM info extraction output string to a list of strings.
Args:
output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']"
Returns:
list: List of strings if successful, error message string if failed
""" # noqa: E501
import ast
try:
result = ast.literal_eval(output_str)
if not isinstance(result, list):
return f"Error: Expected a list, got {type(result)}"
return result
except Exception as e:
return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"
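A quick sanity check of the parsing behavior above; this is an illustrative snippet (the function body is reproduced so it runs standalone), not part of the committed file:

```python
import ast
from typing import List, Union

def parse_llm_output(output_str: str) -> Union[List[str], str]:
    # Same logic as the file above: parse a Python-literal list, or
    # return an error string on malformed input.
    try:
        result = ast.literal_eval(output_str)
        if not isinstance(result, list):
            return f"Error: Expected a list, got {type(result)}"
        return result
    except Exception as e:
        return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"

good = parse_llm_output("['symptom: diarrhea', 'symptom severity: 4-5 times per day']")
bad = parse_llm_output("not a list at all")
print(good)  # -> ['symptom: diarrhea', 'symptom severity: 4-5 times per day']
print(bad.startswith("Error parsing output"))  # -> True
```

Note that the error path returns a `str` rather than raising, so callers must type-check the result before iterating over it.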


@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import json
from typing import Dict, List
def split_single_message_list(messages: List) -> List:
"""
Split a single message list into multiple rounds.
Args:
messages (list): List of message dictionaries with 'role' and 'content' keys
Returns:
list: List of rounds, where each round contains messages and remaining chat
""" # noqa: E501
rounds = []
round_number = 1
i = 0
while i < len(messages):
# Collect messages for this round
round_messages = []
# Add messages until we reach a user message
while i < len(messages) and messages[i].get("role") != "user":
round_messages.append(messages[i])
i += 1
# Add user message(s) - if there are consecutive user messages,
# include all of them in this round
while i < len(messages) and messages[i].get("role") == "user":
round_messages.append(messages[i])
i += 1
# The remaining messages (if any) form the remaining_chat
remaining_messages = messages[i:]
round_entry = {
"round_number": round_number,
"messages": round_messages,
}
# Add remaining chat if there are remaining messages
if remaining_messages:
remaining_chat_parts = []
for msg in remaining_messages:
role = msg.get("role", "")
content = msg.get("content", "")
remaining_chat_parts.append(f"{role}: {content}")
round_entry["remaining_chat"] = "\n".join(remaining_chat_parts)
else:
round_entry["remaining_chat"] = ""
rounds.append(round_entry)
round_number += 1
return rounds
def split_session_to_json_lines(session: Dict) -> List[str]:
"""
Split a session dictionary into multiple rounds and convert to JSON lines.
Args:
session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
- session_id (str): Session identifier
- diagn (str): Diagnosis information
- messages (list): List of message dictionaries with 'role' and 'content' keys
Returns:
list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat
""" # noqa: E501
rounds = split_single_message_list(session["messages"])
json_lines = []
for round_data in rounds:
round_entry = {
"cid": f"{session['session_id']}_{round_data['round_number']}",
"session_id": session["session_id"],
"diagn": session["diagn"],
"messages": round_data["messages"],
"remaining_chat": round_data["remaining_chat"],
}
json_lines.append(json.dumps(round_entry, ensure_ascii=False))
return json_lines
# Example usage:
if __name__ == "__main__":
# Example of splitting a single message list
example_messages = [
{"role": "assistant", "content": "Hello, how can I help you today?"},
{"role": "user", "content": "I've been having headaches lately."},
{
"role": "assistant",
"content": "How long have you been experiencing these headaches?",
},
{"role": "user", "content": "For about a week now."},
{
"role": "assistant",
"content": "I see. Have you taken any medication for them?",
},
{
"role": "user",
"content": "Yes, I've tried some over-the-counter pain relievers.",
},
]
example_session = {
"session_id": "session_1",
"diagn": "migraine",
"messages": example_messages,
}
json_lines = split_session_to_json_lines(example_session)
print("JSON lines output:")
for i, line in enumerate(json_lines):
print(f"Line {i + 1}: {line}")

Binary file not shown (image, 809 KiB)

tuner/learn_to_ask/main.py Normal file

@@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
"""Example of training a ReAct agent on learn-to-ask with Trinity-RFT."""
import os
import re
import time
from typing import Dict, List, Union
from agentscope.tuner import (
    tune,
    AlgorithmConfig,
    DatasetConfig,
    WorkflowOutput,
    JudgeOutput,
    TunerModelConfig,
)
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.memory import InMemoryMemory
from agentscope.model import OpenAIChatModel
AUXILIARY_MODEL_NAME = "auxiliary_model"
TRAIN_MODE = "Ra+Rs"
FUSION_MODE = "default"
def format_messages(
task_desc: Union[List, str],
) -> List[Dict[str, str]]:
"""Format messages for the instruct model."""
if isinstance(task_desc, list):
messages = task_desc
elif isinstance(task_desc, str):
messages = [
{"role": "user", "content": task_desc},
]
else:
raise ValueError("`task_desc` must be a list or a string")
return messages
async def run_react_agent(
task: Dict,
model: OpenAIChatModel,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> WorkflowOutput:
"""A simple workflow function using the ReAct agent to solve tasks.
Args:
task (Dict): The task to be solved.
model (OpenAIChatModel): The language model to use.
auxiliary_models (Dict[str, OpenAIChatModel]):
A dictionary of additional chat models available for
LLM-as-a-Judge. Exactly one auxiliary model must be provided.
Returns:
WorkflowOutput: The workflow output containing the agent's response.
"""
assert (
len(auxiliary_models) == 1
    ), "Please provide exactly one auxiliary model for `learn_to_ask`."
    import importlib.util  # importlib.util is a submodule and must be imported explicitly
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
if TRAIN_MODE == "Ra":
sys_prompt = module.rollout_prompt_med_Ra
else:
sys_prompt = module.rollout_prompt_med
agent = ReActAgent(
name="react_agent",
sys_prompt=sys_prompt,
model=model,
formatter=OpenAIChatFormatter(),
toolkit=None,
memory=InMemoryMemory(),
max_iters=1,
)
messages = format_messages(task["messages"])
response = await agent.reply(
[
Msg(name=x["role"], content=x["content"], role=x["role"])
for x in messages
],
)
return WorkflowOutput(
response=response,
)
def parse_tag_string(text: str) -> Dict:
    """Extract ``<tag>value</tag>`` pairs from the judge's reply into a dict."""
    pattern = r"<(\w+)>(.*?)</\1>"
    # re.DOTALL lets tag values (e.g. the <think> block) span multiple lines.
    matches = re.findall(pattern, text, re.DOTALL)
    result = {}
    for tag, value in matches:
        result[tag] = value
    return result
def merge_msg_list(msg_list: List) -> str:
    """Render a message list as a plain-text doctor/patient transcript."""
    result_str = ""
    for msg in msg_list:
        if msg["role"] == "user":
            result_str += f"patient: {msg['content']}\n"
        elif msg["role"] == "assistant":
            result_str += f"doctor: {msg['content']}\n"
    return result_str
async def llm_reward(
task: Dict,
response: str,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> Dict:
from agentscope import logger
    import importlib.util  # importlib.util is a submodule and must be imported explicitly
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
reward_prompt = module.reward_prompt_med
task_desc = task["messages"]
    info_truth = task.get("info_truth", "None")
history = merge_msg_list(
task_desc + [{"role": "assistant", "content": response}],
)
messages = [
{"role": "system", "content": reward_prompt.format(info_truth)},
{"role": "user", "content": history},
]
try_count, max_retries = 0, 5
while try_count <= max_retries:
try:
client = auxiliary_models[AUXILIARY_MODEL_NAME]
res = await client(messages)
msg = Msg(
name="assistant",
content=list(res.content),
role="assistant",
)
content = msg.get_text_content()
score_dict = parse_tag_string(content)
return score_dict
except Exception as e:
try_count += 1
if try_count > max_retries:
logger.warning("retried too many times, abort task.")
return {}
            else:
                logger.warning(
                    f"error: {e}, response: {response}, retries: {try_count}",
                )
                # Linear backoff; asyncio.sleep avoids blocking the event
                # loop inside this async reward function.
                import asyncio
                await asyncio.sleep(try_count)
return {}
async def learn2ask_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> JudgeOutput:
"""A simple judge function to calculate reward based on agent's response.
Args:
task (Dict): The task information for the corresponding workflow.
response (Msg): The response generated by the corresponding workflow.
auxiliary_models (Dict[str, OpenAIChatModel]):
A dictionary of additional chat models available for LLM-as-a-Judge
usage. The keys are model names, and the values are the
corresponding OpenAIChatModel instances.
Returns:
JudgeOutput: The reward value assigned by the judge function.
"""
assert (
len(auxiliary_models) == 1
    ), "Please provide exactly one auxiliary model for `learn_to_ask`."
    response_text = response.get_text_content() or ""  # guard against None content
    action_truth = task.get("decision_truth", "continue")
action_response = "stop" if "<stop />" in response_text else "continue"
if action_truth == action_response:
action_score = 1.0
if action_truth == "continue":
score_dict = await llm_reward(
task=task,
response=response_text,
auxiliary_models=auxiliary_models,
)
if score_dict != {}:
format_score = float(score_dict.get("format_score", 0.0))
content_score = float(score_dict.get("content_score", 0.0))
else:
format_score, content_score = 0.0, 0.0
else:
content_score = 1.0
format_score = 1.0 if response_text == "<stop />" else 0.0
else:
action_score, format_score, content_score = 0.0, 0.0, 0.0
if TRAIN_MODE == "Ra+Rs": # the default setting
final_reward = (
action_score * (1 + 2 * content_score) + format_score
if FUSION_MODE != "sum"
else action_score + content_score + format_score
)
elif TRAIN_MODE == "Ra": # for Ra only (without Rs)
final_reward = 2 * content_score + format_score
else: # for Rs only (without Ra)
final_reward = action_score * 3 + format_score
return JudgeOutput(
reward=final_reward,
metrics={"reward": final_reward},
)
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path=os.path.join(os.path.dirname(__file__), "data"),
split="train",
total_epochs=4,
)
tuner_model = TunerModelConfig(
model_path="Qwen/Qwen2.5-7B-Instruct",
max_model_len=8192,
max_tokens=1024,
temperature=1.0,
tensor_parallel_size=1,
inference_engine_num=4,
reasoning_parser=None,
)
aux_models = {
AUXILIARY_MODEL_NAME: TunerModelConfig(
model_path="Qwen/Qwen2.5-32B-Instruct",
max_model_len=8192,
max_tokens=1024,
temperature=0.7,
tensor_parallel_size=2,
inference_engine_num=1,
reasoning_parser=None,
),
}
algorithm = AlgorithmConfig(
algorithm_type="grpo",
group_size=5,
learning_rate=5.0e-07,
batch_size=64,
)
tune(
workflow_func=run_react_agent,
judge_func=learn2ask_judge,
train_dataset=dataset,
model=tuner_model,
auxiliary_models=aux_models,
algorithm=algorithm,
config_path=config_path,
)
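The reward fusion at the end of `learn2ask_judge` can be checked in isolation. A minimal sketch (the helper name `fuse_reward` is illustrative, not part of the file):

```python
def fuse_reward(action_score, content_score, format_score,
                train_mode="Ra+Rs", fusion_mode="default"):
    # Mirrors the branching at the end of learn2ask_judge.
    if train_mode == "Ra+Rs":  # default: content reward is gated on the stop/continue decision
        if fusion_mode != "sum":
            return action_score * (1 + 2 * content_score) + format_score
        return action_score + content_score + format_score
    if train_mode == "Ra":  # content/format reward only
        return 2 * content_score + format_score
    return action_score * 3 + format_score  # "Rs": decision reward only

# Correct decision with a perfect question: 1 * (1 + 2) + 1 = 4 (the maximum).
print(fuse_reward(1.0, 1.0, 1.0))  # -> 4.0
# A wrong stop/continue decision zeroes all components under the default mode.
print(fuse_reward(0.0, 0.0, 0.0))  # -> 0.0
```

The multiplicative gating means a model that asks good questions but stops at the wrong time still receives zero content reward, which is what pushes the policy toward timing the interview correctly.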


@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""Prompt templates."""
rollout_prompt_med = """
# Task
You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English.
## Guidelines
- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices.
- Do not repeat any previous question.
- Your response must be a single sentence.
- If enough information has been gathered to make a medication suggestion, output only: <stop />
""" # noqa: E501
rollout_prompt_med_Ra = """
# Task
You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English.
## Guidelines
- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices.
- Do not repeat any previous question.
- Your response must be a single sentence.
""" # noqa: E501
rollout_prompt_med_sft = """
# Task
You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English.
## Guidelines
- If enough information has been gathered to make a medication suggestion, output only: <stop />
""" # noqa: E501
reward_prompt_med = """
# Task
You are an evaluation assistant. The user will provide a dialogue history between a doctor and a patient. You must analyze the dialogue and evaluate the doctor's last message.
# Grading Policy
## Format Score
- 1.0: The doctor's last message contains exactly **one question**.
- 0.5: The doctor's last message contains **two questions**.
- 0.0: The doctor's last message contains **three or more questions**.
## Content Score
- 1.0: The question(s) **directly ask about** any item in the Reference Information.
- 0.5: The question(s) are **highly relevant** to, but not directly asking about, any item in the Reference Information.
- 0.0: The question(s) are **irrelevant** to all items in the Reference Information.
# Reference Information
{}
# Output Format
<think>Explain your reasoning for the format and content scores clearly and concisely.</think>
<format_score>Insert only the format score as a float (e.g., 1.0, 0.5, 0.0)</format_score>
<content_score>Insert only the content score as a float (e.g., 1.0, 0.5, 0.0)</content_score>
> ✅ Important:
> - Output **exactly** the three tags shown above.
> - Do **not** include any additional text, explanation, or formatting outside the tags.
> - Scores must be based **only** on the doctor's **last message** and the provided Reference Information.
> - Ensure clarity and precision in your evaluation reasoning within the `<think>` tag.
""" # noqa: E501
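How these templates are consumed at runtime, as a hedged sketch: the grader prompt is filled with the reference information, and the three-tag reply is parsed with the same regex used by `parse_tag_string` in `main.py` (the placeholder string below stands in for the full template):

```python
import re

reward_prompt_med = "...{}..."  # stand-in for the full template above
system_msg = reward_prompt_med.format("['symptom: headache']")

# A well-formed judge reply uses exactly the three tags the prompt demands.
reply = (
    "<think>One question, directly about the referenced symptom.</think>"
    "<format_score>1.0</format_score>"
    "<content_score>1.0</content_score>"
)
# The backreference </\1> requires each closing tag to match its opener;
# re.DOTALL lets the <think> value span multiple lines.
scores = dict(re.findall(r"<(\w+)>(.*?)</\1>", reply, re.DOTALL))
print(float(scores["format_score"]), float(scores["content_score"]))  # -> 1.0 1.0
```

Because the scores are parsed as floats downstream, any stray text inside the score tags would raise a `ValueError` and trigger the retry loop in `llm_reward`.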