66 lines
1.8 KiB
Python
66 lines
1.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
# pylint: skip-file
|
|
import argparse
|
|
import json
|
|
|
|
|
|
def process_message(json_obj: dict) -> tuple:
    """Derive the keep/decision/info labels for one raw chat record.

    Args:
        json_obj: A parsed JSONL record. Reads the optional keys
            "info_set" (expected: list of strings) and
            "remaining_chat" (expected: string of the chat that follows).

    Returns:
        A tuple ``(if_keep, info_set_str, decision_str)`` where:
        - ``decision_str`` is ``"continue"`` when another user turn remains
          in ``remaining_chat``, otherwise ``"stop"``;
        - ``info_set_str`` joins the ``info_set`` entries with ``", "``
          (empty when the field is absent or not a list);
        - ``if_keep`` is False only for "continue" records that carry no
          info — such records have no training signal.
    """
    info_set = json_obj.get("info_set")
    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""

    # Robustness fix: a missing "remaining_chat" key previously raised
    # KeyError; treat it as "no further user turn", i.e. "stop".
    remaining_chat = json_obj.get("remaining_chat", "")
    decision_str = "continue" if "user: " in remaining_chat else "stop"

    # Keep everything except "continue" records with an empty info set.
    if_keep = bool(info_set_str) or decision_str == "stop"

    return if_keep, info_set_str, decision_str
|
|
|
|
|
|
def main(input_file_path: str, output_file_path: str) -> None:
    """Filter a raw JSONL file into training/testing-ready records.

    Each input line is a JSON object; records for which
    ``process_message`` returns ``if_keep == False`` are dropped, and the
    rest are written out with the derived ``decision_truth`` /
    ``info_truth`` labels attached.

    Args:
        input_file_path: Path to the raw JSONL file (one record per line).
        output_file_path: Path of the filtered JSONL file to write.

    Raises:
        KeyError: If a kept record lacks one of the required fields
            ("cid", "session_id", "diagn", "messages").
    """
    with open(input_file_path, "r", encoding="utf-8") as infile, open(
        output_file_path,
        "w",
        encoding="utf-8",
    ) as outfile:
        print("data processing started...")
        for line in infile:
            line = line.strip()
            # Robustness fix: blank lines previously crashed json.loads;
            # skip them instead.
            if not line:
                continue
            data = json.loads(line)
            if_keep, info_set, decision = process_message(data)
            if not if_keep:
                continue

            new_item = {
                "cid": data["cid"],
                "session_id": data["session_id"],
                "diagn": data["diagn"],
                "messages": data["messages"],
                "decision_truth": decision,
                "info_truth": info_set,
            }
            outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
    print("job done!")
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: filter the raw extraction output into a training file.
    arg_parser = argparse.ArgumentParser()

    # Input: the file produced by 1_info_extract_pipeline.py.
    arg_parser.add_argument(
        "--input_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
    )
    # Output: the final file used for training or testing.
    arg_parser.add_argument(
        "--output_file",
        type=str,
        default="tuner/learn_to_ask/data/train.jsonl",
    )

    cli_args = arg_parser.parse_args()
    main(cli_args.input_file, cli_args.output_file)
|