Add Tuner learn_to_ask example (#101)
tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
import time
from typing import Union

from llm_info_extraction import llm_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines


def process_jsonl_file(
    input_file: str,
    output_file: str,
    model_call_mode: str = "online_api",
    max_retries: int = 3,
    **kwargs: dict,
) -> str:
    """
    Process all sessions in a JSONL file and save results to output file.

    Args:
        input_file (str): Path to input JSONL file
        output_file (str): Path to output JSONL file
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        str: Success message or error information
    """
    try:
        # Read and process each session
        with open(input_file, "r", encoding="utf-8") as infile, open(
            output_file,
            "w",
            encoding="utf-8",
        ) as outfile:
            for line_num, line in enumerate(infile, 1):
                if line.strip():
                    try:
                        session = json.loads(line)
                        print(
                            f"Processing session "
                            f"{session.get('session_id', 'unknown')} "
                            f"(line {line_num})...",
                        )

                        # Process the session
                        processed_lines = process_session(
                            session,
                            model_call_mode,
                            max_retries,
                            **kwargs,
                        )
                        for processed_line in processed_lines:
                            outfile.write(processed_line + "\n")

                    except json.JSONDecodeError as e:
                        print(
                            f"Warning: Skipping invalid JSON at line "
                            f"{line_num}: {e}",
                        )
                    except Exception as e:
                        print(
                            f"Warning: Error processing session at line "
                            f"{line_num}: {e}",
                        )

        return f"Successfully processed. Results saved to {output_file}"

    except Exception as e:
        return f"Error processing JSONL file: {str(e)}"


def process_session(
    session: dict,
    model_call_mode: str = "online_api",
    max_retries: int = 3,
    **kwargs: dict,
) -> Union[list, str]:
    """
    Pipeline function that splits messages into rounds and extracts info from
    each round's remaining chat.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn',
            and 'messages' keys
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        list: List of JSON strings with added "info_set" key,
            or error information
    """
    try:
        # Step 1: Split messages into JSON lines
        json_lines = split_session_to_json_lines(session)

        # Step 2: Process each JSON line with LLM info extraction
        processed_lines = []

        for line in json_lines:
            data = json.loads(line)
            remaining_chat = data.get("remaining_chat", "")

            # Retry loop for LLM calls
            info_set = None
            for attempt in range(max_retries):
                try:
                    # Call LLM info extraction
                    llm_response = llm_info_extraction(
                        remaining_chat,
                        model_call_mode,
                        **kwargs,
                    )

                    info_set = parse_llm_output(llm_response)

                    if isinstance(info_set, list):
                        break
                    else:
                        # If parsing failed, this is an error message
                        print(f"Attempt {attempt + 1} failed: {info_set}")
                        if attempt < max_retries - 1:
                            time.sleep(1)
                except Exception as e:
                    print(
                        f"Attempt {attempt + 1} failed with exception: "
                        f"{str(e)}",
                    )
                    if attempt < max_retries - 1:
                        time.sleep(1)

            data["info_set"] = info_set
            processed_lines.append(json.dumps(data, ensure_ascii=False))

        return processed_lines

    except Exception as e:
        return f"Pipeline error: {str(e)}"


# Example usage:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_origin.jsonl",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
    )
    parser.add_argument(
        "--model_call_mode",
        type=str,
        choices=["online_api", "local_vllm"],
        default="local_vllm",
    )
    parser.add_argument("--model_path", type=str, required=True)
    args = parser.parse_args()
    print(
        process_jsonl_file(
            input_file=args.input_file,
            output_file=args.output_file,
            model_call_mode=args.model_call_mode,
            model_path=args.model_path,
            # Additional parameters for API calls
        ),
    )
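A minimal sketch of the record shapes this pipeline passes around, with hypothetical field values: split_session_to_json_lines emits one record per dialogue round, and process_session attaches the extracted "info_set" before the record is written to the output JSONL.

# Hypothetical input session (one line of train_origin.jsonl):
# {"session_id": "session_1", "diagn": "migraine",
#  "messages": [{"role": "assistant", "content": "..."},
#               {"role": "user", "content": "..."}]}
#
# Hypothetical output record (one line of train_processed.jsonl):
# {"cid": "session_1_1", "session_id": "session_1", "diagn": "migraine",
#  "messages": [...], "remaining_chat": "assistant: ...\nuser: ...",
#  "info_set": ["symptom: headache", "symptom severity: mild"]}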
tuner/learn_to_ask/data_prepare/2_build_dataset.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json


def process_message(json_obj: dict) -> tuple:
    info_set = json_obj.get("info_set")
    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
    if "user: " not in json_obj["remaining_chat"]:
        decision_str = "stop"
    else:
        decision_str = "continue"
    if not info_set_str and decision_str == "continue":
        if_keep = False
    else:
        if_keep = True
    return if_keep, info_set_str, decision_str


def main(input_file_path: str, output_file_path: str) -> None:
    with open(input_file_path, "r", encoding="utf-8") as infile, open(
        output_file_path,
        "w",
        encoding="utf-8",
    ) as outfile:
        print("data processing started...")
        for line in infile:
            data = json.loads(line.strip())
            if_keep, info_set, decision = process_message(data)
            if not if_keep:
                continue

            new_item = {
                "cid": data["cid"],
                "session_id": data["session_id"],
                "diagn": data["diagn"],
                "messages": data["messages"],
                "decision_truth": decision,
                "info_truth": info_set,
            }
            outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
        print("job done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # The file generated by 1_info_extract_pipeline.py
    parser.add_argument(
        "--input_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
    )

    # The final file for training or testing
    parser.add_argument(
        "--output_file",
        type=str,
        default="tuner/learn_to_ask/data/train.jsonl",
    )

    args = parser.parse_args()

    main(args.input_file, args.output_file)
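A quick sketch of how process_message labels each record, using hypothetical inputs:

# remaining_chat still contains a "user: " turn and info_set is non-empty -> keep, decision "continue".
process_message({"info_set": ["symptom: headache"], "remaining_chat": "assistant: ...\nuser: ..."})
# -> (True, "symptom: headache", "continue")

# No further "user: " turn -> keep, decision "stop".
process_message({"info_set": [], "remaining_chat": "assistant: Please take care."})
# -> (True, "", "stop")

# A "continue" round with an empty info_set is dropped (if_keep is False).
process_message({"info_set": [], "remaining_chat": "user: It still hurts."})
# -> (False, "", "continue")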
tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py (new file, 334 lines)
@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
"""
This script uses vLLM to generate rollout samples from the converted
checkpoints.
"""

import argparse
import copy
import gc
import json
import math
import os
import re
import time
from typing import Any, List

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def init_llm(model_path: str) -> tuple:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    device_count = torch.cuda.device_count()
    print(f"device_count={device_count}")
    if device_count < 1:
        raise RuntimeError("No GPU available for multi-card inference.")
    print(f"Loading model from: {model_path}")
    llm = LLM(model=model_path, tensor_parallel_size=device_count)
    print("Model loaded successfully!")
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=0.95,
        max_tokens=512,
        repetition_penalty=1.2,
    )
    return llm, tokenizer, sampling_params


def rollout(
    llm: Any,
    tokenizer: Any,
    sampling_params: Any,
    input_file_path: str,
    output_file_path: str,
    rollout_repeat: int = 3,
) -> None:
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "prompt",
        os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
    )
    module = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)
    rollout_prompt = module.rollout_prompt_med

    with open(input_file_path, "r", encoding="utf-8") as lines:
        sample_list = [json.loads(line.strip()) for line in lines]
    print(f"loaded samples: {len(sample_list)}")

    for index, sample in enumerate(sample_list):
        record = copy.deepcopy(sample)
        print(f"index: {index}, session_id: {sample['session_id']}")
        messages = [{"role": "system", "content": rollout_prompt}] + sample[
            "messages"
        ]

        # Some tokenizers (e.g., Qwen) support the `enable_thinking` argument,
        # but others do not. Try with the argument first, and fall back if
        # it is not accepted.
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except TypeError:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        response_list = []
        for i in range(rollout_repeat):
            time_probe = time.perf_counter()
            outputs = llm.generate([prompt], sampling_params=sampling_params)
            print(f"time cost: {time.perf_counter() - time_probe}")
            for output in outputs:
                response = output.outputs[0].text
                response_list.append(response)
                print(f"rollout #{i}: {response}\n")
        record["rollouts"] = response_list

        # append to output file
        with open(output_file_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def eval_sample(
    llm: Any,
    tokenizer: Any,
    sampling_params: Any,
    input_file_path: str,
    output_file_path: str,
) -> None:
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "prompt",
        os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
    )
    module = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)
    grader_prompt = module.reward_prompt_med

    print(f"input_file_path: {input_file_path}")
    print(f"output_file_path: {output_file_path}")

    with open(input_file_path, "r", encoding="utf-8") as lines:
        sample_list = [json.loads(line.strip()) for line in lines]
    print(f"Total records: {len(sample_list)}")

    def res_formatter(res_content: str) -> dict:
        pattern = r"<(\w+)>(.*?)</\1>"
        matches = re.findall(pattern, res_content)
        result = {}
        for tag_name, content in matches:
            result[tag_name] = content
        return result

    def msg2str(msg_list: List) -> str:
        result_str = ""
        for msg in msg_list:
            if msg["role"] == "user":
                result_str += f"patient: {msg['content']}\n"
            if msg["role"] == "assistant":
                result_str += f"doctor: {msg['content']}\n"
        return result_str

    for index, sample in enumerate(sample_list):
        print(f"index: {index}, cid: {sample['cid']}")
        action_truth = sample["decision_truth"]
        info_truth = sample["info_truth"] if sample["info_truth"] else "None."
        print(f"action_truth: {action_truth}, info_truth:{info_truth}")

        sys_prompt = grader_prompt.format(info_truth)
        history = msg2str(sample["messages"])

        sample["grades"] = []
        for rollout in sample["rollouts"]:
            time_probe = time.perf_counter()
            action_score, content_score, format_score, res_think = (
                0,
                0,
                0,
                "NA",
            )
            if "<stop />" in rollout:
                action_rollout = "stop"
            else:
                action_rollout = "continue"
            if action_truth == action_rollout:
                action_score = 1
                if action_truth == "continue":
                    user_content = history + f"doctor: {rollout}"
                    messages = [
                        {"role": "system", "content": sys_prompt},
                        {"role": "user", "content": user_content},
                    ]
                    try:
                        prompt = tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True,
                            enable_thinking=False,
                        )
                    except TypeError:
                        # Fallback for tokenizers that do not support the
                        # Qwen-specific `enable_thinking` argument.
                        prompt = tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True,
                        )
                    outputs = llm.generate(
                        [prompt],
                        sampling_params=sampling_params,
                    )
                    for output in outputs:
                        response = output.outputs[0].text
                        print(f"Response: {response}\n")
                        res_dict = res_formatter(response)
                        try:
                            format_score = float(res_dict.get("format_score", 0.0))
                            content_score = float(
                                res_dict.get("content_score", 0.0),
                            )
                            res_think = res_dict.get("think", "None")
                        except Exception as e:
                            print(e)
                else:
                    content_score = 1.0
                    format_score = 1.0 if rollout == "<stop />" else 0.0
            else:
                action_score, format_score, content_score = 0, 0, 0
            grade_result = {
                "think": res_think,
                "action_score": action_score,
                "format_score": format_score,
                "content_score": content_score,
            }
            sample["grades"].append(grade_result)
            json_str = json.dumps(grade_result, ensure_ascii=False, indent=2)
            print(
                f"grade_result:{json_str}",
            )
            print(f"time_cost:{time.perf_counter() - time_probe}")
        # append sample to output file
        with open(output_file_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        print("\n======================\n")


def compute_score(input_file_path: str) -> None:
    with open(input_file_path, "r", encoding="utf-8") as lines:
        sample_list = [json.loads(line.strip()) for line in lines]
    continue_count, continue_content_score, continue_content_full = 0, 0, 0
    continue_decision_score = 0
    stop_count, stop_decision_score = 0, 0
    total_reward, total_format = 0, 0
    (
        continue_count_correct,
        continue_content_score_correct,
        continue_content_full_correct,
    ) = (0, 0, 0)
    for sample in sample_list:
        for rollout, grade in zip(sample["rollouts"], sample["grades"]):
            if math.isnan(grade["content_score"]) or math.isnan(
                grade["format_score"],
            ):
                continue
            if sample["decision_truth"] == "continue":
                continue_count += 1
                continue_content_score += grade["content_score"]
                continue_content_full += (
                    1 if grade["content_score"] == 1 else 0
                )
                continue_decision_score += grade["action_score"]
                if "<stop />" not in rollout:
                    continue_count_correct += 1
                    continue_content_score_correct += grade["content_score"]
                    continue_content_full_correct += (
                        1 if grade["content_score"] == 1 else 0
                    )

            else:
                stop_count += 1
                stop_decision_score += grade["action_score"]
            total_reward += (
                grade["action_score"] * (1 + 2 * grade["content_score"])
                + grade["format_score"]
            )
            total_format += grade["format_score"]

    result = {
        "ave_continue_content": continue_content_score
        / max(1, continue_count),
        "win_continue_content": continue_content_full / max(1, continue_count),
        "ave_continue_content if correct": continue_content_score_correct
        / max(1, continue_count_correct),
        "win_continue_content if correct": continue_content_full_correct
        / max(1, continue_count_correct),
        "ave_continue_decision": continue_decision_score
        / max(1, continue_count),
        "ave_stop_decision": stop_decision_score / max(1, stop_count),
        "ave_total_decision": (continue_decision_score + stop_decision_score)
        / max(1, continue_count + stop_count),
        "ave_total_format": total_format / max(1, continue_count + stop_count),
        "ave_total_reward": total_reward / max(1, continue_count + stop_count),
    }

    print(f"total count: {continue_count + stop_count}")
    print(json.dumps(result, ensure_ascii=False, indent=4))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--rollout_repeat", type=int, default=3)

    # Ckpt for testing
    parser.add_argument("--eval_model_path", type=str, required=True)

    # Model to empower the grading, Qwen2.5-32b-instruct is recommended
    parser.add_argument("--grader_model_path", type=str, required=True)

    # Your test sample path [input]
    parser.add_argument("--test_file_path", type=str, required=True)

    # Rollout results given test samples [output]
    parser.add_argument("--rollout_file_path", type=str, required=True)

    # Final output given rollout results [output]
    parser.add_argument("--eval_file_path", type=str, required=True)

    args = parser.parse_args()

    # rollout stage
    llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
    rollout(
        llm,
        tokenizer,
        sampling_params,
        args.test_file_path,
        args.rollout_file_path,
        args.rollout_repeat,
    )
    del llm  # clean up the memory after the inference
    gc.collect()
    torch.cuda.empty_cache()  # release gpu memory

    # eval stage
    llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
    eval_sample(
        llm2,
        tokenizer2,
        sampling_params2,
        args.rollout_file_path,
        args.eval_file_path,
    )
    compute_score(args.eval_file_path)
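For orientation, a hypothetical grader reply and how res_formatter (defined inside eval_sample) and the per-rollout reward treat it; the actual tag set is dictated by reward_prompt_med in prompt.py, which is not part of this diff:

# reply = "<think>Asks about onset.</think><format_score>1</format_score><content_score>0.5</content_score>"
# res_formatter(reply)
# -> {"think": "Asks about onset.", "format_score": "1", "content_score": "0.5"}
#
# compute_score aggregates each graded rollout as
#     reward = action_score * (1 + 2 * content_score) + format_score
# so action_score=1, content_score=0.5, format_score=1 gives a reward of 3.0.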
tuner/learn_to_ask/data_prepare/llm_info_extraction.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import os
from typing import List, Union

import openai
import torch
import transformers

tokenizer = None
llm = None


def llm_info_extraction(
    remaining_chat: str,
    model_call_mode: str,
    **kwargs: dict,
) -> str:
    """
    Extract information from remaining_chat using LLM.

    Args:
        remaining_chat (str): The chat content to process
        model_call_mode (str): Either "online_api" or "local_vllm"
        **kwargs: Additional parameters for API calls

    Returns:
        str: Response text from LLM or error information
    """

    # Create messages format with system and user roles
    system_message = """
# Task:
You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field.
# Requirements:
- Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output.
- Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown").
- Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field.
- Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, ensuring any necessary escape characters within the strings. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
- Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology.
# Example output format:
['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']
"""  # noqa: E501
    user_message = remaining_chat

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": "```\n" + user_message + "\n```\n"},
    ]

    try:
        if model_call_mode == "online_api":
            # OpenAI-style API call
            return _call_online_api(messages, **kwargs)
        elif model_call_mode == "local_vllm":
            # Local vLLM call
            return _call_local_vllm(messages, **kwargs)
        else:
            return (
                f"Error: Invalid model_call_mode '{model_call_mode}'. "
                "Must be 'online_api' or 'local_vllm'."
            )
    except Exception as e:
        return f"Error occurred: {str(e)}"


def _call_online_api(messages: List, **kwargs: dict) -> str:
    """Handle OpenAI-style API calls"""
    # Extract API parameters from kwargs or use defaults
    api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY"))
    api_base = kwargs.get(
        "api_base",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    model = kwargs.get("model", "qwen2.5-72b-instruct")
    temperature = kwargs.get("temperature", 0.7)
    max_tokens = kwargs.get("max_tokens", 500)

    client = openai.OpenAI(api_key=api_key, base_url=api_base)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content


def _call_local_vllm(messages: List, **kwargs: dict) -> str:
    """Handle local vLLM calls"""
    try:
        from vllm import LLM, SamplingParams

        model_path = kwargs.get("model_path")
        if not model_path:
            return "Error: model_path is required for local vLLM inference"

        temperature = kwargs.get("temperature", 0.7)
        max_tokens = kwargs.get("max_tokens", 512)
        top_p = kwargs.get("top_p", 0.9)
        repetition_penalty = kwargs.get("repetition_penalty", 1.1)

        # GPU/CUDA related parameters for vLLM
        tensor_parallel_size = kwargs.get(
            "tensor_parallel_size",
            torch.cuda.device_count(),
        )
        gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9)
        enforce_eager = kwargs.get("enforce_eager", False)
        dtype = kwargs.get("dtype", "auto")
        max_model_len = kwargs.get("max_model_len", 4096)

        # Initialize the LLM with the provided model path and GPU parameters
        global llm, tokenizer
        if llm is None:
            llm = LLM(
                model=model_path,
                tensor_parallel_size=tensor_parallel_size,
                gpu_memory_utilization=gpu_memory_utilization,
                enforce_eager=enforce_eager,
                dtype=dtype,
                max_model_len=max_model_len,
            )

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
        )

        # Convert messages to a single prompt string
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        outputs = llm.generate([prompt], sampling_params)

        return outputs[0].outputs[0].text

    except ImportError:
        return (
            "Error: vLLM library not installed. "
            "Please install it with 'pip install vllm'"
        )
    except Exception as e:
        return f"Error in local vLLM inference: {str(e)}"


def parse_llm_output(output_str: str) -> Union[List[str], str]:
    """
    Convert the LLM info extraction output string to a list of strings.

    Args:
        output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']"

    Returns:
        list: List of strings if successful, error message string if failed
    """  # noqa: E501
    import ast

    try:
        result = ast.literal_eval(output_str)
        if not isinstance(result, list):
            return f"Error: Expected a list, got {type(result)}"

        return result
    except Exception as e:
        return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"
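A short usage sketch for this module; the checkpoint path and extracted values are hypothetical, and online_api mode instead reads an api_key kwarg or the DASHSCOPE_API_KEY environment variable:

raw = llm_info_extraction(
    "assistant: Where does it hurt?\nuser: My lower back, for two days.",
    "local_vllm",
    model_path="/path/to/extractor-checkpoint",  # hypothetical path
)
info_set = parse_llm_output(raw)
# On success, e.g. ["symptom: lower back pain"]; on failure, parse_llm_output
# returns an error string rather than a list, which the caller retries on.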
tuner/learn_to_ask/data_prepare/message_splitter.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import json
from typing import Dict, List


def split_single_message_list(messages: List) -> List:
    """
    Split a single message list into multiple rounds.

    Args:
        messages (list): List of message dictionaries with 'role' and 'content' keys

    Returns:
        list: List of rounds, where each round contains messages and remaining chat
    """  # noqa: E501
    rounds = []
    round_number = 1
    i = 0

    while i < len(messages):
        # Collect messages for this round
        round_messages = []

        # Add messages until we reach a user message
        while i < len(messages) and messages[i].get("role") != "user":
            round_messages.append(messages[i])
            i += 1

        # Add user message(s) - if there are consecutive user messages,
        # include all of them in this round
        while i < len(messages) and messages[i].get("role") == "user":
            round_messages.append(messages[i])
            i += 1

        # The remaining messages (if any) form the remaining_chat
        remaining_messages = messages[i:]
        round_entry = {
            "round_number": round_number,
            "messages": round_messages,
        }

        # Add remaining chat if there are remaining messages
        if remaining_messages:
            remaining_chat_parts = []
            for msg in remaining_messages:
                role = msg.get("role", "")
                content = msg.get("content", "")
                remaining_chat_parts.append(f"{role}: {content}")
            round_entry["remaining_chat"] = "\n".join(remaining_chat_parts)
        else:
            round_entry["remaining_chat"] = ""

        rounds.append(round_entry)
        round_number += 1

    return rounds


def split_session_to_json_lines(session: Dict) -> List[str]:
    """
    Split a session dictionary into multiple rounds and convert to JSON lines.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
            - session_id (str): Session identifier
            - diagn (str): Diagnosis information
            - messages (list): List of message dictionaries with 'role' and 'content' keys

    Returns:
        list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat
    """  # noqa: E501
    rounds = split_single_message_list(session["messages"])

    json_lines = []
    for round_data in rounds:
        round_entry = {
            "cid": f"{session['session_id']}_{round_data['round_number']}",
            "session_id": session["session_id"],
            "diagn": session["diagn"],
            "messages": round_data["messages"],
            "remaining_chat": round_data["remaining_chat"],
        }

        json_lines.append(json.dumps(round_entry, ensure_ascii=False))

    return json_lines


# Example usage:
if __name__ == "__main__":
    # Example of splitting a single message list
    example_messages = [
        {"role": "assistant", "content": "Hello, how can I help you today?"},
        {"role": "user", "content": "I've been having headaches lately."},
        {
            "role": "assistant",
            "content": "How long have you been experiencing these headaches?",
        },
        {"role": "user", "content": "For about a week now."},
        {
            "role": "assistant",
            "content": "I see. Have you taken any medication for them?",
        },
        {
            "role": "user",
            "content": "Yes, I've tried some over-the-counter pain relievers.",
        },
    ]

    example_session = {
        "session_id": "session_1",
        "diagn": "migraine",
        "messages": example_messages,
    }
    json_lines = split_session_to_json_lines(example_session)
    print("JSON lines output:")
    for i, line in enumerate(json_lines):
        print(f"Line {i + 1}: {line}")