from typing import Union

import numpy as np

from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import TensorType


class PadAndSortCollator:
    """Collate TTS examples: pad token ids and mel spectrograms, build a
    stop-gate target from the audio attention mask, and sort the whole
    batch by descending text length."""

    def __init__(self, processor, return_tensors: Union[str, TensorType] = "pt"):
        self.processor = processor
        self.return_tensors = return_tensors

    def __call__(self, batch):
        """
        Expects items produced by the processor with `return_tensors=None`.
        Each item carries: input_ids, length (optional), mel_specgram,
        mel_specgram_length (optional).
        """
        # ---- text side ---------------------------------------------------
        token_lengths = (
            [item["length"] for item in batch]
            if "length" in batch[0]
            else [len(item["input_ids"]) for item in batch]
        )
        text_batch = {
            "input_ids": [item["input_ids"] for item in batch],
            "length": token_lengths,
        }

        # ---- audio side --------------------------------------------------
        # Transpose each spectrogram to (time, n_mels) so padding runs
        # along the leading axis.
        mel_lengths = (
            [item["mel_specgram_length"] for item in batch]
            if "mel_specgram_length" in batch[0]
            else [item["mel_specgram"][0].shape[1] for item in batch]
        )
        audio_batch = {
            "mel_specgram": [
                item["mel_specgram"][0].transpose(1, 0) for item in batch
            ],
            "mel_specgram_length": mel_lengths,
        }

        text_batch = self.processor.tokenizer.pad(
            text_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=False,
        )
        audio_batch = self.processor.feature_extractor.pad(
            audio_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        # Undo the earlier transpose: back to (batch, n_mels, time).
        audio_batch["mel_specgram"] = audio_batch["mel_specgram"].transpose(0, 2, 1)

        # Gate target: 0 on valid frames, 1 from the last valid frame
        # onward (mask is 1 on valid frames; shift left so the stop flag
        # lands on the final valid frame, and force the last column to 1).
        attention_mask = audio_batch.pop("attention_mask")
        gate_padded = np.roll(1 - attention_mask, -1, axis=1)
        gate_padded[:, -1] = 1
        gate_padded = gate_padded.astype(np.float32)

        output = {**text_batch, **audio_batch, "gate_padded": gate_padded}

        # Reorder every field so text lengths are descending.
        sort_idx = np.argsort(output["length"])[::-1]
        for key in output:
            output[key] = output[key][sort_idx]

        return BatchFeature(output, tensor_type=self.return_tensors)