""" helper util to calculate dataset lengths """ import numpy as np def get_dataset_lengths(dataset): if "length" in dataset.data.column_names: lengths = np.array(dataset.data.column("length")) elif "position_ids" in dataset.data.column_names: position_ids = dataset.data.column("position_ids") lengths = np.array([x[-1] + 1 for x in position_ids]) else: input_ids = dataset.data.column("input_ids") lengths = np.vectorize(len)(np.array(input_ids, dtype=object)) return lengths return lengths