Files
evotraders/tuner/learn_to_ask/data_prepare/2_build_dataset.py
2026-01-16 19:24:46 +08:00

66 lines
1.8 KiB
Python

# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
def process_message(json_obj: dict) -> tuple:
    """Derive filtering labels for a single raw conversation record.

    Args:
        json_obj: One parsed jsonl record; must contain "remaining_chat"
            and may contain an "info_set" list of strings.

    Returns:
        A ``(if_keep, info_set_str, decision_str)`` tuple where
        ``decision_str`` is "continue" when the remaining chat still has a
        user turn (contains "user: ") and "stop" otherwise,
        ``info_set_str`` is the comma-joined info set (empty when absent or
        not a list), and ``if_keep`` is False only for the useless case of
        "continue with no extracted info".
    """
    raw_info = json_obj.get("info_set")
    # Non-list values (None, str, ...) collapse to an empty info string.
    info_text = ", ".join(raw_info) if isinstance(raw_info, list) else ""
    decision = "continue" if "user: " in json_obj["remaining_chat"] else "stop"
    # Keep the record unless it asks to continue but carries no info.
    keep = bool(info_text) or decision == "stop"
    return keep, info_text, decision
def main(input_file_path: str, output_file_path: str) -> None:
    """Filter a raw jsonl file into the final training/testing jsonl.

    Reads one JSON object per line from *input_file_path*, drops records
    that `process_message` marks as not worth keeping, and writes the
    remaining records (with truth labels attached) to *output_file_path*.

    Args:
        input_file_path: Path to the jsonl produced by the extract pipeline.
        output_file_path: Path of the filtered jsonl to create/overwrite.
    """
    with open(input_file_path, "r", encoding="utf-8") as src, open(
        output_file_path,
        "w",
        encoding="utf-8",
    ) as dst:
        print("data processing started...")
        for raw_line in src:
            record = json.loads(raw_line.strip())
            keep, info_truth, decision_truth = process_message(record)
            if not keep:
                # Record adds no training signal; skip it.
                continue
            filtered = {
                "cid": record["cid"],
                "session_id": record["session_id"],
                "diagn": record["diagn"],
                "messages": record["messages"],
                "decision_truth": decision_truth,
                "info_truth": info_truth,
            }
            dst.write(json.dumps(filtered, ensure_ascii=False) + "\n")
    print("job done!")
if __name__ == "__main__":
    # CLI entry point: wire input/output paths into main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        # The file generated by 1_info_extract_pipeline.py
        default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        # The final file for training or testing
        default="tuner/learn_to_ask/data/train.jsonl",
    )
    cli_args = parser.parse_args()
    main(cli_args.input_file, cli_args.output_file)