# NEOX / tools / datasets / merge_datasets.py
# akswelh's picture
# Upload 251 files
# d90b3a8 verified
import os
import sys
import json
import argparse
sys.path.append(
os.path.abspath(
os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
)
)
from megatron.data import indexed_dataset
def main(args):
    """Merge every indexed dataset found in ``args.input`` into one dataset.

    Each regular file directly inside ``args.input`` is treated as one half of
    a ``<prefix>.bin`` / ``<prefix>.idx`` pair. The pairs are merged in sorted
    prefix order into ``args.output_prefix + ".bin"`` and
    ``args.output_prefix + ".idx"``.

    Args:
        args: Namespace with ``input`` (directory holding the dataset files)
            and ``output_prefix`` (output path without suffix).

    Raises:
        AssertionError: If a file in ``args.input`` has no matching
            ``.bin``/``.idx`` partner.
        ValueError: If ``args.input`` contains no regular files at all, so
            there is nothing to merge.
    """
    prefixes = set()
    for basename in os.listdir(args.input):
        prefix, ext = os.path.splitext(basename)

        # The partner file of an already accepted prefix needs no second check.
        if prefix in prefixes:
            continue

        # Sub-directories and other non-regular entries are ignored.
        if not os.path.isfile(os.path.join(args.input, basename)):
            continue

        # Any extension other than ".idx" is expected to have a ".idx" partner.
        ext_pair = ".bin" if ext == ".idx" else ".idx"
        if not os.path.isfile(os.path.join(args.input, prefix) + ext_pair):
            # Raised explicitly (instead of `assert`) so the check still runs
            # under `python -O`; the exception type and message are unchanged.
            raise AssertionError(
                f"ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}"
            )

        prefixes.add(prefix)

    # Without this guard an empty directory would leave `builder` as None and
    # fail below with an unhelpful AttributeError.
    if not prefixes:
        raise ValueError(f"ERROR: no dataset files found in {args.input}")

    builder = None
    for prefix in sorted(prefixes):
        if builder is None:
            # Open the first dataset only to decide which builder class (and,
            # for the mmap variant, which dtype) the merged output should use.
            dataset = indexed_dataset.make_dataset(
                os.path.join(args.input, prefix), "infer"
            )

            if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
                builder = indexed_dataset.MMapIndexedDatasetBuilder(
                    args.output_prefix + ".bin", dtype=dataset._index.dtype
                )
            else:
                builder = indexed_dataset.IndexedDatasetBuilder(
                    args.output_prefix + ".bin"
                )

            # Drop the probe dataset before merging starts.
            del dataset

        builder.merge_file_(os.path.join(args.input, prefix))

    builder.finalize(args.output_prefix + ".idx")
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, validate both paths, then
    # hand over to main().
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title="input data")
    group.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to directory containing all document files to merge",
    )

    group = parser.add_argument_group(title="output data")
    group.add_argument(
        "--output-prefix",
        type=str,
        required=True,
        help="Path to binary output file without suffix",
    )

    args = parser.parse_args()

    # Explicit checks rather than `assert`, so validation is not skipped when
    # the script runs under `python -O`.
    if not os.path.isdir(args.input):
        sys.exit(f"ERROR: {args.input} is not a directory or does not exist")

    # A bare prefix such as "merged" has an empty dirname; that means the
    # current directory, which os.path.isdir("") would wrongly reject.
    output_dir = os.path.dirname(args.output_prefix) or os.path.curdir
    if not os.path.isdir(output_dir):
        sys.exit(f"ERROR: {output_dir} is not a directory or does not exist")

    main(args)