Add Tuner learn_to_ask example (#101)

This commit is contained in:
chenyushuo
2026-01-16 19:24:46 +08:00
committed by GitHub
parent 5855c5161b
commit 3821fb04ac
10 changed files with 1643 additions and 0 deletions

View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
import time
from typing import Union
from llm_info_extraction import llm_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines
def process_jsonl_file(
input_file: str,
output_file: str,
model_call_mode: str = "online_api",
max_retries: int = 3,
**kwargs: dict,
) -> str:
"""
Process all sessions in a JSONL file and save results to output file.
Args:
input_file (str): Path to input JSONL file
output_file (str): Path to output JSONL file
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls
Returns:
str: Success message or error information
"""
try:
# Read and process each session
with open(input_file, "r", encoding="utf-8") as infile, open(
output_file,
"w",
encoding="utf-8",
) as outfile:
for line_num, line in enumerate(infile, 1):
if line.strip():
try:
session = json.loads(line)
print(
f"Processing session "
f"{session.get('session_id', 'unknown')} "
f"(line {line_num})...",
)
# Process the session
processed_lines = process_session(
session,
model_call_mode,
max_retries,
**kwargs,
)
for processed_line in processed_lines:
outfile.write(processed_line + "\n")
except json.JSONDecodeError as e:
print(
f"Warning: Skipping invalid JSON at line "
f"{line_num}: {e}",
)
except Exception as e:
print(
f"Warning: Error processing session at line "
f"{line_num}: {e}",
)
return f"Successfully processed. Results saved to {output_file}"
except Exception as e:
return f"Error processing JSONL file: {str(e)}"
def process_session(
session: dict,
model_call_mode: str = "online_api",
max_retries: int = 3,
**kwargs: dict,
) -> Union[list, str]:
"""
Pipeline function that splits messages into rounds and extracts info from
each round's remaining chat.
Args:
session (dict): Session dictionary containing 'session_id', 'diagn',
and 'messages' keys
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls
Returns:
list: List of JSON strings with added "info_set" key,
or error information
"""
try:
# Step 1: Split messages into JSON lines
json_lines = split_session_to_json_lines(session)
# Step 2: Process each JSON line with LLM info extraction
processed_lines = []
for line in json_lines:
data = json.loads(line)
remaining_chat = data.get("remaining_chat", "")
# Retry loop for LLM calls
info_set = None
for attempt in range(max_retries):
try:
# Call LLM info extraction
# (using mock function for testing)
llm_response = llm_info_extraction(
remaining_chat,
model_call_mode,
**kwargs,
)
info_set = parse_llm_output(llm_response)
if isinstance(info_set, list):
break
else:
# If parsing failed, this is an error message
print(f"Attempt {attempt + 1} failed: {info_set}")
if attempt < max_retries - 1:
time.sleep(1)
except Exception as e:
print(
f"Attempt {attempt + 1} failed with exception: "
f"{str(e)}",
)
if attempt < max_retries - 1:
time.sleep(1) # Shorter wait for testing
data["info_set"] = info_set
processed_lines.append(json.dumps(data, ensure_ascii=False))
return processed_lines
except Exception as e:
return f"Pipeline error: {str(e)}"
# Example usage:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_origin.jsonl",
)
parser.add_argument(
"--output_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
)
parser.add_argument(
"--model_call_mode",
type=str,
choices=["online_api", "local_vllm"],
default="local_vllm",
)
parser.add_argument("--model_path", type=str, required=True)
args = parser.parse_args()
print(
process_jsonl_file(
input_file=args.input_file,
output_file=args.output_file,
model_call_mode=args.model_call_mode,
model_path=args.model_path,
# Additional parameters for API calls
),
)
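# For reference, an illustrative sketch (not part of the committed script;
# field names follow the docstrings above, values are made up) of the JSONL
# record shapes this pipeline reads and writes, one JSON object per line.
input_session = {
    "session_id": "session_1",
    "diagn": "migraine",
    "messages": [
        {"role": "assistant", "content": "Hello, how can I help you today?"},
        {"role": "user", "content": "I've been having headaches lately."},
    ],
}
# Each output line is one dialogue round (see message_splitter) with the
# extracted attributes attached under "info_set".
output_record = {
    "cid": "session_1_1",  # f"{session_id}_{round_number}"
    "session_id": "session_1",
    "diagn": "migraine",
    "messages": input_session["messages"],
    "remaining_chat": "",  # flattened text of the turns after this round
    "info_set": ["symptom: headache"],
}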

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
def process_message(json_obj: dict) -> tuple:
info_set = json_obj.get("info_set")
info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
if "user: " not in json_obj["remaining_chat"]:
decision_str = "stop"
else:
decision_str = "continue"
if not info_set_str and decision_str == "continue":
if_keep = False
else:
if_keep = True
return if_keep, info_set_str, decision_str
def main(input_file_path: str, output_file_path: str) -> None:
with open(input_file_path, "r", encoding="utf-8") as infile, open(
output_file_path,
"w",
encoding="utf-8",
) as outfile:
print("data processing started...")
for line in infile:
data = json.loads(line.strip())
if_keep, info_set, decision = process_message(data)
if not if_keep:
continue
new_item = {
"cid": data["cid"],
"session_id": data["session_id"],
"diagn": data["diagn"],
"messages": data["messages"],
"decision_truth": decision,
"info_truth": info_set,
}
outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
print("job done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# The file generated by 1_info_extract_pipeline.py
parser.add_argument(
"--input_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
)
# The final file for training or testing
parser.add_argument(
"--output_file",
type=str,
default="tuner/learn_to_ask/data/train.jsonl",
)
args = parser.parse_args()
main(args.input_file, args.output_file)
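# Illustrative sketch of one kept training record (keys mirror new_item
# above; values are made up). Rounds with an empty info_set and a
# "continue" decision are dropped by process_message.
kept_record = {
    "cid": "session_1_1",
    "session_id": "session_1",
    "diagn": "migraine",
    "messages": [...],  # dialogue prefix for this round
    "decision_truth": "continue",  # "stop" when no user turn remains
    "info_truth": "symptom: headache, symptom severity: 4-5 times per day",
}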

View File

@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
"""
This script uses vLLM to generate rollout samples from the converted
checkpoints.
"""
import argparse
import copy
import gc
import json
import math
import os
import re
import time
from typing import Any, List
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
def init_llm(model_path: str) -> tuple:
tokenizer = AutoTokenizer.from_pretrained(model_path)
device_count = torch.cuda.device_count()
print(f"device_count={device_count}")
if device_count < 1:
raise RuntimeError("No GPU available for multi-card inference.")
print(f"Loading model from: {model_path}")
llm = LLM(model=model_path, tensor_parallel_size=device_count)
print("Model loaded successfully!")
sampling_params = SamplingParams(
temperature=1.0,
top_p=0.95,
max_tokens=512,
repetition_penalty=1.2,
)
return llm, tokenizer, sampling_params
def rollout(
llm: Any,
tokenizer: Any,
sampling_params: Any,
input_file_path: str,
output_file_path: str,
rollout_repeat: int = 3,
) -> None:
import importlib.util  # spec_from_file_location lives in importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
rollout_prompt = module.rollout_prompt_med
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
print(f"loaded samples: {len(sample_list)}")
for index, sample in enumerate(sample_list):
record = copy.deepcopy(sample)
print(f"index: {index}, session_id: {sample['session_id']}")
messages = [{"role": "system", "content": rollout_prompt}] + sample[
"messages"
]
# Some tokenizers (e.g., Qwen) support the `enable_thinking` argument,
# but others do not. Try with the argument first, and fall back if
# it is not accepted.
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
response_list = []
for i in range(rollout_repeat):
time_probe = time.perf_counter()
outputs = llm.generate([prompt], sampling_params=sampling_params)
print(f"time cost: {time.perf_counter() - time_probe}")
for output in outputs:
response = output.outputs[0].text
response_list.append(response)
print(f"rollout #{i}: {response}\n")
record["rollouts"] = response_list
# append to output file
with open(output_file_path, "a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
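# Illustrative sketch of one line appended to the rollout output file: the
# original test record plus `rollout_repeat` sampled doctor turns (values
# are made up).
rollout_record = {
    "cid": "session_1_1",
    "session_id": "session_1",
    "diagn": "migraine",
    "messages": [...],
    "decision_truth": "continue",
    "info_truth": "symptom: headache",
    "rollouts": [
        "How long have the headaches lasted?",
        "<stop />",
        "Do they worsen with light or noise?",
    ],
}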
def eval_sample(
llm: Any,
tokenizer: Any,
sampling_params: Any,
input_file_path: str,
output_file_path: str,
) -> None:
import importlib.util  # spec_from_file_location lives in importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
grader_prompt = module.reward_prompt_med
print(f"input_file_path: {input_file_path}")
print(f"output_file_path: {output_file_path}")
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
print(f"Total records: {len(sample_list)}")
def res_formatter(res_content: str) -> dict:
pattern = r"<(\w+)>(.*?)</\1>"
matches = re.findall(pattern, res_content)
result = {}
for tag_name, content in matches:
result[tag_name] = content
return result
def msg2str(msg_list: List) -> str:
result_str = ""
for msg in msg_list:
if msg["role"] == "user":
result_str += f"patient: {msg['content']}\n"
if msg["role"] == "assistant":
result_str += f"doctor: {msg['content']}\n"
return result_str
for index, sample in enumerate(sample_list):
print(f"index: {index}, cid: {sample['cid']}")
action_truth = sample["decision_truth"]
info_truth = sample["info_truth"] if sample["info_truth"] else "None."
print(f"action_truth: {action_truth}, info_truth:{info_truth}")
sys_prompt = grader_prompt.format(info_truth)
history = msg2str(sample["messages"])
sample["grades"] = []
for rollout in sample["rollouts"]:
time_probe = time.perf_counter()
action_score, content_score, format_score, res_think = (
0,
0,
0,
"NA",
)
if "<stop />" in rollout:
action_rollout = "stop"
else:
action_rollout = "continue"
if action_truth == action_rollout:
action_score = 1
if action_truth == "continue":
user_content = history + f"doctor: {rollout}"
messages = [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": user_content},
]
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
# Fallback for tokenizers that do not support the
# Qwen-specific `enable_thinking` argument.
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
outputs = llm.generate(
[prompt],
sampling_params=sampling_params,
)
for output in outputs:
response = output.outputs[0].text
print(f"Response: {response}\n")
res_dict = res_formatter(response)
try:
format_score = float(res_dict.get("format_score", 0.0))
content_score = float(
res_dict.get("content_score", 0.0),
)
res_think = res_dict.get("think", "None")
except Exception as e:
print(e)
else:
content_score = 1.0
format_score = 1.0 if rollout == "<stop />" else 0.0
else:
action_score, format_score, content_score = 0, 0, 0
grade_result = {
"think": res_think,
"action_score": action_score,
"format_score": format_score,
"content_score": content_score,
}
sample["grades"].append(grade_result)
json_str = json.dumps(grade_result, ensure_ascii=False, indent=2)
print(
f"grade_result:{json_str}",
)
print(f"time_cost:{time.perf_counter() - time_probe}")
# append sample to output file
with open(output_file_path, "a", encoding="utf-8") as f:
f.write(json.dumps(sample, ensure_ascii=False) + "\n")
print("\n======================\n")
def compute_score(input_file_path: str) -> None:
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
continue_count, continue_content_score, continue_content_full = 0, 0, 0
continue_decision_score = 0
stop_count, stop_decision_score = 0, 0
total_reward, total_format = 0, 0
(
continue_count_correct,
continue_content_score_correct,
continue_content_full_correct,
) = (0, 0, 0)
for sample in sample_list:
for rollout, grade in zip(sample["rollouts"], sample["grades"]):
if math.isnan(grade["content_score"]) or math.isnan(
grade["format_score"],
):
continue
if sample["decision_truth"] == "continue":
continue_count += 1
continue_content_score += grade["content_score"]
continue_content_full += (
1 if grade["content_score"] == 1 else 0
)
continue_decision_score += grade["action_score"]
if "<stop />" not in rollout:
continue_count_correct += 1
continue_content_score_correct += grade["content_score"]
continue_content_full_correct += (
1 if grade["content_score"] == 1 else 0
)
else:
stop_count += 1
stop_decision_score += grade["action_score"]
total_reward += (
grade["action_score"] * (1 + 2 * grade["content_score"])
+ grade["format_score"]
)
total_format += grade["format_score"]
result = {
"ave_continue_content": continue_content_score
/ max(1, continue_count),
"win_continue_content": continue_content_full / max(1, continue_count),
"ave_continue_content if correct": continue_content_score_correct
/ max(1, continue_count_correct),
"win_continue_content if correct": continue_content_full_correct
/ max(1, continue_count_correct),
"ave_continue_decision": continue_decision_score
/ max(1, continue_count),
"ave_stop_decision": stop_decision_score / max(1, stop_count),
"ave_total_decision": (continue_decision_score + stop_decision_score)
/ max(1, continue_count + stop_count),
"ave_total_format": total_format / max(1, continue_count + stop_count),
"ave_total_reward": total_reward / max(1, continue_count + stop_count),
}
print(f"total count: {continue_count + stop_count}")
print(json.dumps(result, ensure_ascii=False, indent=4))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--rollout_repeat", type=int, default=3)
# Checkpoint to evaluate
parser.add_argument("--eval_model_path", type=str, required=True)
# Model used for grading; Qwen2.5-32B-Instruct is recommended
parser.add_argument("--grader_model_path", type=str, required=True)
# Your test sample path [input]
parser.add_argument("--test_file_path", type=str, required=True)
# Rollout results given test samples [output]
parser.add_argument("--rollout_file_path", type=str, required=True)
# Final output given rollout results [output]
parser.add_argument("--eval_file_path", type=str, required=True)
args = parser.parse_args()
# rollout stage
llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
rollout(
llm,
tokenizer,
sampling_params,
args.test_file_path,
args.rollout_file_path,
args.rollout_repeat,
)
del llm # clean up the memory after the inference
gc.collect()
torch.cuda.empty_cache() # release gpu memory
# eval stage
llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
eval_sample(
llm2,
tokenizer2,
sampling_params2,
args.rollout_file_path,
args.eval_file_path,
)
compute_score(args.eval_file_path)
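# The per-rollout reward aggregated in compute_score can be restated as a
# standalone helper (hypothetical; not used by the script itself):
def rollout_reward(
    action_score: float,
    content_score: float,
    format_score: float,
) -> float:
    # Decision correctness gates a content bonus; format compliance adds on top.
    return action_score * (1 + 2 * content_score) + format_score
# A correct "continue" with full content and format credit scores 4.0:
assert rollout_reward(1, 1, 1) == 4.0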

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import os
from typing import List, Union
import openai
import torch
import transformers
tokenizer = None
llm = None
def llm_info_extraction(
remaining_chat: str,
model_call_mode: str,
**kwargs: dict,
) -> str:
"""
Extract information from remaining_chat using LLM.
Args:
remaining_chat (str): The chat content to process
model_call_mode (str): Either "online_api" or "local_vllm"
**kwargs: Additional parameters for API calls
Returns:
str: Response text from LLM or error information
"""
# Create messages format with system and user roles
system_message = """
# Task:
You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field.
# Requirements:
- Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output.
- Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown").
- Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field.
- Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, escaping characters within the strings where necessary. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
- Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology.
# Example output format:
['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']
""" # noqa: E501
user_message = remaining_chat
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": "```\n" + user_message + "\n```\n"},
]
try:
if model_call_mode == "online_api":
# OpenAI-style API call
return _call_online_api(messages, **kwargs)
elif model_call_mode == "local_vllm":
# Local vLLM call
return _call_local_vllm(messages, **kwargs)
else:
return (
f"Error: Invalid model_call_mode '{model_call_mode}'. "
"Must be 'online_api' or 'local_vllm'."
)
except Exception as e:
return f"Error occurred: {str(e)}"
def _call_online_api(messages: List, **kwargs: dict) -> str:
"""Handle OpenAI-style API calls"""
# Extract API parameters from kwargs or use defaults
api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY"))
api_base = kwargs.get(
"api_base",
"https://dashscope.aliyuncs.com/compatible-mode/v1",
)
model = kwargs.get("model", "qwen2.5-72b-instruct")
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 500)
client = openai.OpenAI(api_key=api_key, base_url=api_base)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return response.choices[0].message.content
def _call_local_vllm(messages: List, **kwargs: dict) -> str:
"""Handle local vLLM calls"""
try:
from vllm import LLM, SamplingParams
model_path = kwargs.get("model_path")
if not model_path:
return "Error: model_path is required for local vLLM inference"
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 512)
top_p = kwargs.get("top_p", 0.9)
repetition_penalty = kwargs.get("repetition_penalty", 1.1)
# GPU/CUDA related parameters for vLLM
tensor_parallel_size = kwargs.get(
"tensor_parallel_size",
torch.cuda.device_count(),
)
gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9)
enforce_eager = kwargs.get("enforce_eager", False)
dtype = kwargs.get("dtype", "auto")
max_model_len = kwargs.get("max_model_len", 4096)
# Initialize the LLM with the provided model path and GPU parameters
global llm, tokenizer
if llm is None:
llm = LLM(
model=model_path,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
dtype=dtype,
max_model_len=max_model_len,
)
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
)
# Convert messages to a single prompt string
if tokenizer is None:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
outputs = llm.generate([prompt], sampling_params)
return outputs[0].outputs[0].text
except ImportError:
return (
"Error: vLLM library not installed. "
"Please install it with 'pip install vllm'"
)
except Exception as e:
return f"Error in local vLLM inference: {str(e)}"
def parse_llm_output(output_str: str) -> Union[List[str], str]:
"""
Convert the LLM info extraction output string to a list of strings.
Args:
output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']"
Returns:
list: List of strings if successful, error message string if failed
""" # noqa: E501
import ast
try:
result = ast.literal_eval(output_str)
if not isinstance(result, list):
return f"Error: Expected a list, got {type(result)}"
return result
except Exception as e:
return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import json
from typing import Dict, List
def split_single_message_list(messages: List) -> List:
"""
Split a single message list into multiple rounds.
Args:
messages (list): List of message dictionaries with 'role' and 'content' keys
Returns:
list: List of rounds, where each round contains messages and remaining chat
""" # noqa: E501
rounds = []
round_number = 1
i = 0
while i < len(messages):
# Collect messages for this round
round_messages = []
# Add messages until we reach a user message
while i < len(messages) and messages[i].get("role") != "user":
round_messages.append(messages[i])
i += 1
# Add user message(s) - if there are consecutive user messages,
# include all of them in this round
while i < len(messages) and messages[i].get("role") == "user":
round_messages.append(messages[i])
i += 1
# The remaining messages (if any) form the remaining_chat
remaining_messages = messages[i:]
round_entry = {
"round_number": round_number,
"messages": round_messages,
}
# Add remaining chat if there are remaining messages
if remaining_messages:
remaining_chat_parts = []
for msg in remaining_messages:
role = msg.get("role", "")
content = msg.get("content", "")
remaining_chat_parts.append(f"{role}: {content}")
round_entry["remaining_chat"] = "\n".join(remaining_chat_parts)
else:
round_entry["remaining_chat"] = ""
rounds.append(round_entry)
round_number += 1
return rounds
def split_session_to_json_lines(session: Dict) -> List[str]:
"""
Split a session dictionary into multiple rounds and convert to JSON lines.
Args:
session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
- session_id (str): Session identifier
- diagn (str): Diagnosis information
- messages (list): List of message dictionaries with 'role' and 'content' keys
Returns:
list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat
""" # noqa: E501
rounds = split_single_message_list(session["messages"])
json_lines = []
for round_data in rounds:
round_entry = {
"cid": f"{session['session_id']}_{round_data['round_number']}",
"session_id": session["session_id"],
"diagn": session["diagn"],
"messages": round_data["messages"],
"remaining_chat": round_data["remaining_chat"],
}
json_lines.append(json.dumps(round_entry, ensure_ascii=False))
return json_lines
# Example usage:
if __name__ == "__main__":
# Example of splitting a single message list
example_messages = [
{"role": "assistant", "content": "Hello, how can I help you today?"},
{"role": "user", "content": "I've been having headaches lately."},
{
"role": "assistant",
"content": "How long have you been experiencing these headaches?",
},
{"role": "user", "content": "For about a week now."},
{
"role": "assistant",
"content": "I see. Have you taken any medication for them?",
},
{
"role": "user",
"content": "Yes, I've tried some over-the-counter pain relievers.",
},
]
example_session = {
"session_id": "session_1",
"diagn": "migraine",
"messages": example_messages,
}
json_lines = split_session_to_json_lines(example_session)
print("JSON lines output:")
for i, line in enumerate(json_lines):
print(f"Line {i + 1}: {line}")