Files
evotraders/tuner/learn_to_ask/data_prepare/message_splitter.py
2026-01-16 19:24:46 +08:00

120 lines
3.9 KiB
Python

# -*- coding: utf-8 -*-
# pylint: skip-file
import json
from typing import Dict, List
def split_single_message_list(messages: List) -> List:
    """
    Break one conversation into successive rounds ending on user turns.

    A round gathers every message up to the next user message, plus that
    user message and any user messages that directly follow it. Trailing
    non-user messages with no later user message form a final round.

    Args:
        messages (list): Message dictionaries. The 'role' and 'content'
            keys are read with ``.get``, so either may be absent.

    Returns:
        list: One dict per round with the keys 'round_number' (1-based),
            'messages' (only the messages of that round) and
            'remaining_chat' (all later messages rendered as
            "role: content" lines joined by newlines, or "" if none).
    """
    total = len(messages)
    result = []
    cursor = 0
    while cursor < total:
        start = cursor
        # Advance past the non-user messages that open this round.
        while cursor < total and messages[cursor].get("role") != "user":
            cursor += 1
        # Absorb the user message and any consecutive user messages.
        while cursor < total and messages[cursor].get("role") == "user":
            cursor += 1
        # Everything after the round is flattened into plain text; joining
        # an empty sequence yields "", which covers the last round.
        pairs = (
            (later.get("role", ""), later.get("content", ""))
            for later in messages[cursor:]
        )
        remaining_chat = "\n".join(
            f"{role}: {content}" for role, content in pairs
        )
        result.append(
            {
                "round_number": len(result) + 1,
                "messages": list(messages[start:cursor]),
                "remaining_chat": remaining_chat,
            }
        )
    return result
def split_session_to_json_lines(session: Dict) -> List[str]:
    """
    Turn one session into a list of JSON strings, one per round.

    The session's messages are split into rounds with
    ``split_single_message_list``; each round is then serialized together
    with the session-level fields.

    Args:
        session (dict): Session dictionary with these keys:
            - session_id (str): Session identifier
            - diagn (str): Diagnosis information
            - messages (list): Message dictionaries with 'role' and
              'content' keys

    Returns:
        list: JSON strings (non-ASCII characters are kept unescaped), each
            holding cid ("<session_id>_<round_number>"), session_id, diagn,
            messages, and remaining_chat for one round.
    """
    # Session fields are looked up per round, so a session with no messages
    # yields an empty list without touching 'session_id' or 'diagn'.
    return [
        json.dumps(
            {
                "cid": f"{session['session_id']}_{round_data['round_number']}",
                "session_id": session["session_id"],
                "diagn": session["diagn"],
                "messages": round_data["messages"],
                "remaining_chat": round_data["remaining_chat"],
            },
            ensure_ascii=False,
        )
        for round_data in split_single_message_list(session["messages"])
    ]
# Example usage:
if __name__ == "__main__":
    # Demo conversation written as (role, content) pairs; expanded below
    # into the message dictionaries the splitter expects.
    example_dialogue = [
        ("assistant", "Hello, how can I help you today?"),
        ("user", "I've been having headaches lately."),
        (
            "assistant",
            "How long have you been experiencing these headaches?",
        ),
        ("user", "For about a week now."),
        ("assistant", "I see. Have you taken any medication for them?"),
        ("user", "Yes, I've tried some over-the-counter pain relievers."),
    ]
    example_session = {
        "session_id": "session_1",
        "diagn": "migraine",
        "messages": [
            {"role": role, "content": content}
            for role, content in example_dialogue
        ],
    }
    # Split first, then print, so nothing is emitted if splitting fails.
    json_lines = split_session_to_json_lines(example_session)
    print("JSON lines output:")
    for i, line in enumerate(json_lines):
        print(f"Line {i + 1}: {line}")