|
from typing import Union |
|
|
|
import numpy as np |
|
from transformers.utils import TensorType |
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
|
|
|
class PadAndSortCollator:
    """Collate processor outputs into a padded batch sorted by text length.

    Text fields are padded with ``processor.tokenizer.pad`` and audio fields
    with ``processor.feature_extractor.pad``. A stop-token ("gate") target is
    derived from the audio padding mask. Every field is then reordered so that
    ``length`` is non-increasing across the batch.
    """

    def __init__(self, processor, return_tensors: Union[str, TensorType] = "pt"):
        # processor must expose `.tokenizer.pad` and `.feature_extractor.pad`.
        self.processor = processor
        # Tensor type of the returned BatchFeature (e.g. "pt", "np").
        self.return_tensors = return_tensors

    def __call__(self, batch):
        """Pad a list of processor outputs, build gate targets and sort the batch.

        Expects examples produced by the processor with ``return_tensors=None``.

        Args:
            batch: non-empty sequence of mappings with keys ``input_ids`` and
                ``mel_specgram``, plus optional ``length`` and
                ``mel_specgram_length``. The code takes ``mel_specgram[0]`` and
                treats it as ``(n_mels, frames)``. That presumably means a
                leading channel axis of size 1; confirm against the processor.

        Returns:
            BatchFeature containing ``input_ids``, ``length``,
            ``mel_specgram`` (``(batch, n_mels, max_frames)``),
            ``mel_specgram_length`` and ``gate_padded``
            (``(batch, max_frames)``, float32). All fields are sorted by
            ``length`` in descending order and converted to
            ``self.return_tensors``.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("PadAndSortCollator received an empty batch")

        text_batch = self.processor.tokenizer.pad(
            self._collect_text(batch),
            padding=True,
            return_tensors="np",
            return_attention_mask=False,
        )

        audio_batch = self.processor.feature_extractor.pad(
            self._collect_audio(batch),
            padding=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        # Padding was applied along the time axis (frames-first); restore
        # the (batch, n_mels, frames) layout.
        audio_batch["mel_specgram"] = audio_batch["mel_specgram"].transpose(0, 2, 1)
        gate_padded = self._gate_targets(audio_batch.pop("attention_mask"))

        output = {**text_batch, **audio_batch, "gate_padded": gate_padded}

        # Sort in descending order of text length. A stable sort on the negated
        # lengths keeps tied examples in their original batch order;
        # `argsort(...)[::-1]` would reverse them. Cast to int64 so the
        # negation cannot wrap for unsigned inputs.
        lengths = np.asarray(output["length"], dtype=np.int64)
        sort_idx = np.argsort(-lengths, kind="stable")
        output = {key: value[sort_idx] for key, value in output.items()}

        return BatchFeature(output, tensor_type=self.return_tensors)

    @staticmethod
    def _collect_text(batch):
        """Gather token ids and per-example text lengths into columnar lists."""
        input_ids = [example["input_ids"] for example in batch]
        if "length" in batch[0]:
            lengths = [example["length"] for example in batch]
        else:
            # Fall back to the length of the unpadded token sequence.
            lengths = [len(ids) for ids in input_ids]
        return {"input_ids": input_ids, "length": lengths}

    @staticmethod
    def _collect_audio(batch):
        """Gather frames-first spectrograms and their frame counts."""
        # np.asarray accepts nested lists as well as arrays. `.T` turns
        # (n_mels, frames) into (frames, n_mels) so the feature extractor
        # pads along the time axis.
        mels = [np.asarray(example["mel_specgram"][0]).T for example in batch]
        if "mel_specgram_length" in batch[0]:
            mel_lengths = [example["mel_specgram_length"] for example in batch]
        else:
            # Number of frames of each unpadded spectrogram.
            mel_lengths = [mel.shape[0] for mel in mels]
        return {"mel_specgram": mels, "mel_specgram_length": mel_lengths}

    @staticmethod
    def _gate_targets(attention_mask):
        """Build stop-token targets from a (batch, frames) audio padding mask.

        Assumes the HF convention that the mask is 1 for real frames and 0 for
        padding. The result is 0 for every real frame except the last one, and
        1 from the last real frame onward (including all padding).
        """
        gate = 1 - attention_mask
        # Shift left by one so the last real frame is also marked 1. The
        # wrapped-around first column lands in the final column, which is
        # overwritten just below.
        gate = np.roll(gate, -1, axis=1)
        gate[:, -1] = 1
        return gate.astype(np.float32)
|
|