from typing import Union
import numpy as np
from transformers.utils import TensorType
from transformers.feature_extraction_utils import BatchFeature


class PadAndSortCollator:
    def __init__(self, processor, return_tensors: Union[str, TensorType] = "pt"):
        self.processor = processor
        self.return_tensors = return_tensors
    def __call__(self, batch):
        """
        Expects a batch of items produced by the processor with
        ``return_tensors=None``. Each item contains: input_ids,
        length (optional), mel_specgram, mel_specgram_length (optional).
        """
        # Collect token ids and their lengths; fall back to computing
        # lengths from input_ids when the batch does not provide them.
        text_batch = {}
        text_batch["input_ids"] = [x["input_ids"] for x in batch]
        if "length" in batch[0]:
            text_batch["length"] = [x["length"] for x in batch]
        else:
            text_batch["length"] = [len(x["input_ids"]) for x in batch]
        audio_batch = {}
        # mel_specgram arrives as (1, n_mels, time); drop the leading dim and
        # transpose to (time, n_mels) so padding is applied along the time axis.
        audio_batch["mel_specgram"] = [
            x["mel_specgram"][0].transpose(1, 0) for x in batch
        ]
        if "mel_specgram_length" in batch[0]:
            audio_batch["mel_specgram_length"] = [
                x["mel_specgram_length"] for x in batch
            ]
        else:
            audio_batch["mel_specgram_length"] = [
                x["mel_specgram"][0].shape[1] for x in batch
            ]
        # Pad both modalities to the longest item in the batch. Work in numpy
        # here; the final tensor conversion happens in BatchFeature below.
        text_batch = self.processor.tokenizer.pad(
            text_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=False,
        )
        audio_batch = self.processor.feature_extractor.pad(
            audio_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        # Undo the earlier transpose: (batch, time, n_mels) -> (batch, n_mels, time).
        audio_batch["mel_specgram"] = audio_batch["mel_specgram"].transpose(0, 2, 1)
        # Build the stop-gate targets from the padding mask: 0 over valid
        # frames, 1 from the final valid frame onward.
        attention_mask = audio_batch.pop("attention_mask")
        gate_padded = 1 - attention_mask
        gate_padded = np.roll(gate_padded, -1, axis=1)
        gate_padded[:, -1] = 1
        gate_padded = gate_padded.astype(np.float32)
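        # Worked example (comment only): for an attention_mask row
        # [1, 1, 1, 0, 0], 1 - mask gives [0, 0, 0, 1, 1], the roll shifts it
        # to [0, 0, 1, 1, 0], and forcing the last entry to 1 yields
        # [0, 0, 1, 1, 1]: the gate is 1 from the final valid frame onward.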
        output = {**text_batch, **audio_batch, "gate_padded": gate_padded}
        # Sort every tensor in the batch by descending text length.
        sort_idx = np.argsort(output["length"])[::-1]
        for key, value in output.items():
            output[key] = value[sort_idx]
        return BatchFeature(output, tensor_type=self.return_tensors)
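

# A minimal usage sketch (assumptions, not part of this file: ``processor`` is
# an object exposing ``.tokenizer`` and ``.feature_extractor``, and ``dataset``
# yields dicts produced by that processor with ``return_tensors=None``):
#
#     from torch.utils.data import DataLoader
#
#     collator = PadAndSortCollator(processor, return_tensors="pt")
#     loader = DataLoader(dataset, batch_size=8, collate_fn=collator)
#     batch = next(iter(loader))
#     # -> input_ids, length, mel_specgram, mel_specgram_length, gate_padded,
#     #    all sorted by descending text length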