Files
evotraders/tuner/werewolves/prepare_data.py
2026-01-16 17:25:49 +08:00

72 lines
1.9 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# flake8: noqa: E501
"""Prepare dataset for werewolf game training.
This script generates a simple JSONL dataset of RNG seeds (the integers 0 to num_seeds - 1) used for role shuffling.
Each seed creates a different initial role assignment, ensuring diverse training scenarios.
"""
import json
import argparse
from pathlib import Path
def prepare_dataset(
    output_dir: str,
    num_seeds: int = 300,
    split: str = "train",
) -> None:
    """Prepare the werewolf game training dataset.

    Writes ``<output_dir>/<split>.jsonl`` with one JSON object per line of the
    form ``{"seed": i}`` for every ``i`` in ``range(num_seeds)``. The output
    directory (including missing parents) is created if necessary, and an
    existing file with the same name is overwritten.

    Args:
        output_dir (str): Directory to save the dataset.
        num_seeds (int): Number of seeds to generate; must be non-negative.
            Default: 300.
        split (str): Dataset split name (e.g., 'train', 'eval'). Default: 'train'.

    Raises:
        ValueError: If ``num_seeds`` is negative.
    """
    # Validate before touching the filesystem: a negative count would
    # otherwise silently produce an empty file and report a negative total.
    if num_seeds < 0:
        raise ValueError(f"num_seeds must be non-negative, got {num_seeds}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    output_file = output_path / f"{split}.jsonl"

    print(f"Generating {num_seeds} seeds for {split} split...")
    with open(output_file, "w", encoding="utf-8") as f:
        # One JSON record per line (JSONL); the generator keeps memory flat.
        f.writelines(
            json.dumps({"seed": seed}) + "\n" for seed in range(num_seeds)
        )

    print(f"Dataset saved to: {output_file}")
    print(f"Total samples: {num_seeds}")
if __name__ == "__main__":
    # Command-line entry point: collect the three options and forward them
    # unchanged to prepare_dataset().
    cli = argparse.ArgumentParser(
        description="Prepare dataset for werewolf game training",
    )

    # (flag, type, default, help) for every supported option.
    option_specs = (
        ("--output_dir", str, "data", "Directory to save the dataset (default: data)"),
        ("--num_seeds", int, 300, "Number of seeds to generate (default: 300)"),
        ("--split", str, "train", "Dataset split name (default: train)"),
    )
    for flag, value_type, fallback, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    options = cli.parse_args()
    prepare_dataset(
        output_dir=options.output_dir,
        num_seeds=options.num_seeds,
        split=options.split,
    )