"""
helper util to calculate dataset lengths
"""
import numpy as np


def get_dataset_lengths(dataset):
    """
    Return one length value per row of ``dataset`` as a numpy array.

    ``dataset.data`` must expose ``column_names`` and ``column(name)``
    (presumably a pyarrow-backed table -- confirm against callers).

    * When a ``"length"`` column is present, its values are converted with
      ``np.array`` and returned unchanged.
    * Otherwise the ``"position_ids"`` column is converted to pandas and each
      row's length is taken as its last position id plus one.
      NOTE(review): this equals the row's token count only if position ids
      count up from 0 without resetting -- verify for packed samples.
    """
    table = dataset.data

    # Fallback path first: no precomputed lengths, so derive them.
    if "length" not in table.column_names:
        position_series = table.column("position_ids").to_pandas()
        derived = position_series.apply(lambda ids: ids[-1] + 1)
        return derived.values

    # Precomputed lengths are available; just materialize them.
    return np.array(table.column("length"))