# -*- coding: utf-8 -*-
# pylint: skip-file
import json
|
|
from typing import Dict, List
|
|
|
|
|
|
def split_single_message_list(messages: List) -> List:
    """
    Split a single message list into multiple rounds.

    A round is a run of zero or more non-user messages followed by every
    consecutive user message that comes right after them. A trailing run of
    non-user messages (with no user message after it) forms a final round on
    its own.

    Args:
        messages (list): List of message dictionaries with 'role' and
            'content' keys.

    Returns:
        list: One dictionary per round, in order, with the keys:
            - "round_number" (int): 1-based index of the round.
            - "messages" (list): The messages that belong to this round.
            - "remaining_chat" (str): Every message after this round,
              rendered as "role: content" lines joined with newlines, or an
              empty string when nothing is left.
    """
    # Render each message once up front; every round's remaining chat is
    # simply a suffix of these lines, so there is no need to re-format the
    # tail of the conversation for each round.
    chat_lines = [
        f"{msg.get('role', '')}: {msg.get('content', '')}" for msg in messages
    ]

    total = len(messages)
    rounds = []
    i = 0

    while i < total:
        start = i

        # Leading non-user messages (e.g. assistant turns) open the round.
        while i < total and messages[i].get("role") != "user":
            i += 1

        # Consecutive user messages all belong to the same round.
        while i < total and messages[i].get("role") == "user":
            i += 1

        rounds.append(
            {
                "round_number": len(rounds) + 1,
                "messages": list(messages[start:i]),
                # Joining an empty suffix yields "", which covers the last
                # round without a separate branch.
                "remaining_chat": "\n".join(chat_lines[i:]),
            }
        )

    return rounds
|
|
|
|
|
|
def split_session_to_json_lines(session: Dict) -> List[str]:
    """
    Split a session dictionary into multiple rounds and convert to JSON lines.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn',
            and 'messages' keys:
            - session_id (str): Session identifier.
            - diagn (str): Diagnosis information.
            - messages (list): List of message dictionaries with 'role' and
              'content' keys.

    Returns:
        list: List of JSON strings, one per round. Each encodes an object
            with the keys cid ("<session_id>_<round_number>"), session_id,
            diagn, messages, and remaining_chat.

    Raises:
        KeyError: If 'messages' is missing from the session, or if
            'session_id' or 'diagn' is missing and at least one round is
            produced.
    """
    rounds = split_single_message_list(session["messages"])

    # ensure_ascii=False keeps non-ASCII text readable in the output instead
    # of escaping it as \uXXXX sequences.
    return [
        json.dumps(
            {
                "cid": f"{session['session_id']}_{round_data['round_number']}",
                "session_id": session["session_id"],
                "diagn": session["diagn"],
                "messages": round_data["messages"],
                "remaining_chat": round_data["remaining_chat"],
            },
            ensure_ascii=False,
        )
        for round_data in rounds
    ]
|
|
|
|
|
|
# Example usage:
if __name__ == "__main__":
    # A short doctor/patient style conversation used to demonstrate how a
    # session is split into rounds.
    example_messages = [
        {"role": "assistant", "content": "Hello, how can I help you today?"},
        {"role": "user", "content": "I've been having headaches lately."},
        {
            "role": "assistant",
            "content": "How long have you been experiencing these headaches?",
        },
        {"role": "user", "content": "For about a week now."},
        {
            "role": "assistant",
            "content": "I see. Have you taken any medication for them?",
        },
        {
            "role": "user",
            "content": "Yes, I've tried some over-the-counter pain relievers.",
        },
    ]

    example_session = {
        "session_id": "session_1",
        "diagn": "migraine",
        "messages": example_messages,
    }

    json_lines = split_session_to_json_lines(example_session)

    print("JSON lines output:")
    # Number the printed lines from 1 for readability.
    for line_number, line in enumerate(json_lines, start=1):
        print(f"Line {line_number}: {line}")
|